//===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation between the MLIR OpenMP dialect and LLVM
// IR.
//
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Transforms/RegionUtils.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"

#include <any>
#include <cstdint>
#include <iterator>
#include <numeric>
#include <optional>
#include <utility>

using namespace mlir;

namespace {
/// Translates an optional MLIR schedule clause kind into the corresponding
/// OpenMPIRBuilder schedule kind. The absence of a schedule clause maps to the
/// implementation-defined default schedule.
static llvm::omp::ScheduleKind
convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
  if (!schedKind.has_value())
    return llvm::omp::OMP_SCHEDULE_Default;
  switch (schedKind.value()) {
  case omp::ClauseScheduleKind::Static:
    return llvm::omp::OMP_SCHEDULE_Static;
  case omp::ClauseScheduleKind::Dynamic:
    return llvm::omp::OMP_SCHEDULE_Dynamic;
  case omp::ClauseScheduleKind::Guided:
    return llvm::omp::OMP_SCHEDULE_Guided;
  case omp::ClauseScheduleKind::Auto:
    return llvm::omp::OMP_SCHEDULE_Auto;
  case omp::ClauseScheduleKind::Runtime:
    return llvm::omp::OMP_SCHEDULE_Runtime;
  }
  llvm_unreachable("unhandled schedule clause argument");
}

/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
/// insertion points for allocas.
class OpenMPAllocaStackFrame
    : public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPAllocaStackFrame)

  explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
      : allocaInsertPoint(allocaIP) {}
  // Insertion point at which allocas created while translating nested OpenMP
  // regions should be placed.
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
};

/// ModuleTranslation stack frame containing the partial mapping between MLIR
/// values and their LLVM IR equivalents.
class OpenMPVarMappingStackFrame
    : public LLVM::ModuleTranslation::StackFrameBase<
          OpenMPVarMappingStackFrame> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPVarMappingStackFrame)

  explicit OpenMPVarMappingStackFrame(
      const DenseMap<Value, llvm::Value *> &mapping)
      : mapping(mapping) {}

  DenseMap<Value, llvm::Value *> mapping;
};

/// Custom error class to signal translation errors that don't need reporting,
/// since encountering them will have already triggered relevant error messages.
///
/// Its purpose is to serve as the glue between MLIR failures represented as
/// \see LogicalResult instances and \see llvm::Error instances used to
/// propagate errors through the \see llvm::OpenMPIRBuilder. Generally, when an
/// error of the first type is raised, a message is emitted directly (the \see
/// LogicalResult itself does not hold any information). If we need to forward
/// this error condition as an \see llvm::Error while avoiding triggering some
/// redundant error reporting later on, we need a custom \see llvm::ErrorInfo
/// class to just signal this situation has happened.
///
/// For example, this class should be used to trigger errors from within
/// callbacks passed to the \see OpenMPIRBuilder when they were triggered by the
/// translation of their own regions. This unclutters the error log from
/// redundant messages.
class PreviouslyReportedError
    : public llvm::ErrorInfo<PreviouslyReportedError> {
public:
  void log(raw_ostream &) const override {
    // Do not log anything: the diagnostic was already emitted at the point
    // where the original LogicalResult failure was produced.
  }

  std::error_code convertToErrorCode() const override {
    llvm_unreachable(
        "PreviouslyReportedError doesn't support ECError conversion");
  }

  // Used by ErrorInfo::classID.
  static char ID;
};

char PreviouslyReportedError::ID = 0;

} // namespace

/// Looks up the omp.private operation named \p symbolName, searching the
/// nearest symbol table starting from \p from. The privatizer is expected to
/// exist (verified by the dialect), so a failed lookup asserts.
static omp::PrivateClauseOp findPrivatizer(Operation *from,
                                           SymbolRefAttr symbolName) {
  omp::PrivateClauseOp privatizer =
      SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(from,
                                                                 symbolName);
  assert(privatizer && "privatizer not found in the symbol table");
  return privatizer;
}

/// Check whether translation to LLVM IR for the given operation is currently
/// supported. If not, descriptive diagnostics will be emitted to let users know
/// this is a not-yet-implemented feature.
///
/// \returns success if no unimplemented features are needed to translate the
/// given operation.
static LogicalResult checkImplementationStatus(Operation &op) {
  // Emits a "not yet implemented" diagnostic for the given clause name; the
  // returned failure becomes the overall result of the check.
  auto todo = [&op](StringRef clauseName) {
    return op.emitError() << "not yet implemented: Unhandled clause "
                          << clauseName << " in " << op.getName()
                          << " operation";
  };

  // Each checker below inspects one clause on `op` (a concrete operation type,
  // hence `auto`) and overwrites `result` with a failure if the clause is
  // present but unsupported.
  auto checkAllocate = [&todo](auto op, LogicalResult &result) {
    if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
      result = todo("allocate");
  };
  auto checkBare = [&todo](auto op, LogicalResult &result) {
    if (op.getBare())
      result = todo("ompx_bare");
  };
  auto checkDepend = [&todo](auto op, LogicalResult &result) {
    if (!op.getDependVars().empty() || op.getDependKinds())
      result = todo("depend");
  };
  auto checkDevice = [&todo](auto op, LogicalResult &result) {
    if (op.getDevice())
      result = todo("device");
  };
  auto checkHasDeviceAddr = [&todo](auto op, LogicalResult &result) {
    if (!op.getHasDeviceAddrVars().empty())
      result = todo("has_device_addr");
  };
  auto checkHint = [](auto op, LogicalResult &) {
    // `hint` is not an error: it is legal to ignore it, so only warn.
    if (op.getHint())
      op.emitWarning("hint clause discarded");
  };
  auto checkHostEval = [](auto op, LogicalResult &result) {
    // Host evaluated clauses are supported, except for loop bounds.
    for (BlockArgument arg :
         cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs())
      for (Operation *user : arg.getUsers())
        if (isa<omp::LoopNestOp>(user))
          result = op.emitError("not yet implemented: host evaluation of loop "
                                "bounds in omp.target operation");
  };
  auto checkInReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
        op.getInReductionSyms())
      result = todo("in_reduction");
  };
  auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) {
    if (!op.getIsDevicePtrVars().empty())
      result = todo("is_device_ptr");
  };
  auto checkLinear = [&todo](auto op, LogicalResult &result) {
    if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
      result = todo("linear");
  };
  auto checkNontemporal = [&todo](auto op, LogicalResult &result) {
    if (!op.getNontemporalVars().empty())
      result = todo("nontemporal");
  };
  auto checkNowait = [&todo](auto op, LogicalResult &result) {
    if (op.getNowait())
      result = todo("nowait");
  };
  auto checkOrder = [&todo](auto op, LogicalResult &result) {
    if (op.getOrder() || op.getOrderMod())
      result = todo("order");
  };
  auto checkParLevelSimd = [&todo](auto op, LogicalResult &result) {
    if (op.getParLevelSimd())
      result = todo("parallelization-level");
  };
  auto checkPriority = [&todo](auto op, LogicalResult &result) {
    if (op.getPriority())
      result = todo("priority");
  };
  auto checkPrivate = [&todo](auto op, LogicalResult &result) {
    if constexpr (std::is_same_v<std::decay_t<decltype(op)>, omp::TargetOp>) {
      // Privatization clauses are supported, except on some situations, so we
      // need to check here whether any of these unsupported cases are being
      // translated.
      if (std::optional<ArrayAttr> privateSyms = op.getPrivateSyms()) {
        for (Attribute privatizerNameAttr : *privateSyms) {
          omp::PrivateClauseOp privatizer = findPrivatizer(
              op.getOperation(), cast<SymbolRefAttr>(privatizerNameAttr));

          if (privatizer.getDataSharingType() ==
              omp::DataSharingClauseType::FirstPrivate)
            result = todo("firstprivate");
        }
      }
    } else {
      // For operations other than omp.target, privatization is not supported
      // at all yet.
      if (!op.getPrivateVars().empty() || op.getPrivateSyms())
        result = todo("privatization");
    }
  };
  auto checkReduction = [&todo](auto op, LogicalResult &result) {
    if (isa<omp::TeamsOp>(op) || isa<omp::SimdOp>(op))
      if (!op.getReductionVars().empty() || op.getReductionByref() ||
          op.getReductionSyms())
        result = todo("reduction");
    if (op.getReductionMod() &&
        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
      result = todo("reduction with modifier");
  };
  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
        op.getTaskReductionSyms())
      result = todo("task_reduction");
  };
  auto checkUntied = [&todo](auto op, LogicalResult &result) {
    if (op.getUntied())
      result = todo("untied");
  };

  LogicalResult result = success();
  llvm::TypeSwitch<Operation &>(op)
      .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
      .Case([&](omp::SectionsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::SingleOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
      })
      .Case([&](omp::TeamsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::TaskOp op) {
        checkAllocate(op, result);
        checkInReduction(op, result);
      })
      .Case([&](omp::TaskgroupOp op) {
        checkAllocate(op, result);
        checkTaskReduction(op, result);
      })
      .Case([&](omp::TaskwaitOp op) {
        checkDepend(op, result);
        checkNowait(op, result);
      })
      .Case([&](omp::TaskloopOp op) {
        // TODO: Add other clauses check
        checkUntied(op, result);
        checkPriority(op, result);
      })
      .Case([&](omp::WsloopOp op) {
        checkAllocate(op, result);
        checkLinear(op, result);
        checkOrder(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::ParallelOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::SimdOp op) {
        checkLinear(op, result);
        checkNontemporal(op, result);
        checkReduction(op, result);
      })
      .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
            omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
      .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
          [&](auto op) { checkDepend(op, result); })
      .Case([&](omp::TargetOp op) {
        checkAllocate(op, result);
        checkBare(op, result);
        checkDevice(op, result);
        checkHasDeviceAddr(op, result);
        checkHostEval(op, result);
        checkInReduction(op, result);
        checkIsDevicePtr(op, result);
        checkPrivate(op, result);
      })
      .Default([](Operation &) {
        // Assume all clauses for an operation can be translated unless they are
        // checked above.
      });
  return result;
}

/// Maps an \c llvm::Error coming out of the OpenMPIRBuilder back to a
/// \c LogicalResult. A \c PreviouslyReportedError converts to a plain failure
/// (its message was already emitted); any other error is reported as a new
/// diagnostic attached to \p op.
static LogicalResult handleError(llvm::Error error, Operation &op) {
  LogicalResult result = success();
  if (error) {
    llvm::handleAllErrors(
        std::move(error),
        [&](const PreviouslyReportedError &) { result = failure(); },
        [&](const llvm::ErrorInfoBase &err) {
          result = op.emitError(err.message());
        });
  }
  return result;
}

/// Overload of \c handleError for \c llvm::Expected values: consumes and
/// reports the error if \p result is in the error state, otherwise succeeds.
template <typename T>
static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
  if (!result)
    return handleError(result.takeError(), op);

  return success();
}

/// Find the insertion point for allocas given the current insertion point for
/// normal operations in the builder.
static llvm::OpenMPIRBuilder::InsertPointTy
findAllocaInsertPoint(llvm::IRBuilderBase &builder,
                      const LLVM::ModuleTranslation &moduleTranslation) {
  // If there is an alloca insertion point on stack, i.e. we are in a nested
  // operation and a specific point was provided by some surrounding operation,
  // use it.
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
  WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
      [&](const OpenMPAllocaStackFrame &frame) {
        // The innermost frame wins; interrupt the walk at the first hit.
        allocaInsertPoint = frame.allocaInsertPoint;
        return WalkResult::interrupt();
      });
  if (walkResult.wasInterrupted())
    return allocaInsertPoint;

  // Otherwise, insert to the entry block of the surrounding function.
  // If the current IRBuilder InsertPoint is the function's entry, it cannot
  // also be used for alloca insertion which would result in insertion order
  // confusion. Create a new BasicBlock for the Builder and use the entry block
  // for the allocas.
  // TODO: Create a dedicated alloca BasicBlock at function creation such that
  // we do not need to move the current InsertPoint here.
  if (builder.GetInsertBlock() ==
      &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
    assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
           "Assuming end of basic block");
    llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
        builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    builder.CreateBr(entryBB);
    builder.SetInsertPoint(entryBB);
  }

  llvm::BasicBlock &funcEntryBlock =
      builder.GetInsertBlock()->getParent()->getEntryBlock();
  return llvm::OpenMPIRBuilder::InsertPointTy(
      &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
}

/// Converts the given region that appears within an OpenMP dialect operation to
/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
/// region, and a branch from any block with a successor-less OpenMP terminator
/// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
/// of the continuation block if provided.
static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
  // Split off everything after the current insertion point into the
  // continuation block; control returns there once the region is done.
  llvm::BasicBlock *continuationBlock =
      splitBB(builder, true, "omp.region.cont");
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();

  // Pre-create an LLVM basic block for every MLIR block so that branches
  // between them can be translated before their contents are.
  llvm::LLVMContext &llvmContext = builder.getContext();
  for (Block &bb : region) {
    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
        llvmContext, blockName, builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    moduleTranslation.mapBlock(&bb, llvmBB);
  }

  llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();

  // Terminators (namely YieldOp) may be forwarding values to the region that
  // need to be available in the continuation block. Collect the types of these
  // operands in preparation of creating PHI nodes. All yields must forward the
  // same number and types of values; this is asserted below.
  SmallVector<llvm::Type *> continuationBlockPHITypes;
  bool operandsProcessed = false;
  unsigned numYields = 0;
  for (Block &bb : region.getBlocks()) {
    if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
      if (!operandsProcessed) {
        // First yield seen: record the operand types.
        for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
          continuationBlockPHITypes.push_back(
              moduleTranslation.convertType(yield->getOperand(i).getType()));
        }
        operandsProcessed = true;
      } else {
        // Subsequent yields: verify they agree with the first one.
        assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
               "mismatching number of values yielded from the region");
        for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
          llvm::Type *operandType =
              moduleTranslation.convertType(yield->getOperand(i).getType());
          (void)operandType;
          assert(continuationBlockPHITypes[i] == operandType &&
                 "values of mismatching types yielded from the region");
        }
      }
      numYields++;
    }
  }

  // Insert PHI nodes in the continuation block for any values forwarded by the
  // terminators in this region.
  if (!continuationBlockPHITypes.empty())
    assert(
        continuationBlockPHIs &&
        "expected continuation block PHIs if converted regions yield values");
  if (continuationBlockPHIs) {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
    builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
    for (llvm::Type *ty : continuationBlockPHITypes)
      continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
  }

  // Convert blocks one by one in topological order to ensure
  // defs are converted before uses.
  SetVector<Block *> blocks = getBlocksSortedByDominance(region);
  for (Block *bb : blocks) {
    llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
    // Retarget the branch of the entry block to the entry block of the
    // converted region (regions are single-entry).
    if (bb->isEntryBlock()) {
      assert(sourceTerminator->getNumSuccessors() == 1 &&
             "provided entry block has multiple successors");
      assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
             "ContinuationBlock is not the successor of the entry block");
      sourceTerminator->setSuccessor(0, llvmBB);
    }

    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    if (failed(
            moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
      return llvm::make_error<PreviouslyReportedError>();

    // Special handling for `omp.yield` and `omp.terminator` (we may have more
    // than one): they return the control to the parent OpenMP dialect operation
    // so replace them with the branch to the continuation block. We handle this
    // here to avoid relying inter-function communication through the
    // ModuleTranslation class to set up the correct insertion point. This is
    // also consistent with MLIR's idiom of handling special region terminators
    // in the same code that handles the region-owning operation.
    Operation *terminator = bb->getTerminator();
    if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
      builder.CreateBr(continuationBlock);

      for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
        (*continuationBlockPHIs)[i]->addIncoming(
            moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
    }
  }
  // After all blocks have been traversed and values mapped, connect the PHI
  // nodes to the results of preceding blocks.
  LLVM::detail::connectPHINodes(region, moduleTranslation);

  // Remove the blocks and values defined in this region from the mapping since
  // they are not visible outside of this region. This allows the same region to
  // be converted several times, that is cloned, without clashes, and slightly
  // speeds up the lookups.
  moduleTranslation.forgetMapping(region);

  return continuationBlock;
}

/// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
  switch (kind) {
  case omp::ClauseProcBindKind::Close:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
  case omp::ClauseProcBindKind::Master:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
  case omp::ClauseProcBindKind::Primary:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
  case omp::ClauseProcBindKind::Spread:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
  }
  llvm_unreachable("Unknown ClauseProcBindKind kind");
}

/// Helper function to map block arguments defined by ignored loop wrappers to
/// LLVM values and prevent any uses of those from triggering null pointer
/// dereferences.
///
/// This must be called after block arguments of parent wrappers have already
/// been mapped to LLVM IR values.
static LogicalResult
convertIgnoredWrapper(omp::LoopWrapperInterface &opInst,
                      LLVM::ModuleTranslation &moduleTranslation) {
  // Map block arguments directly to the LLVM value associated to the
  // corresponding operand. This is semantically equivalent to this wrapper not
  // being present.
  auto forwardArgs =
      [&moduleTranslation](llvm::ArrayRef<BlockArgument> blockArgs,
                           OperandRange operands) {
        for (auto [arg, var] : llvm::zip_equal(blockArgs, operands))
          moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
      };

  return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
      .Case([&](omp::SimdOp op) {
        auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*op);
        forwardArgs(blockArgIface.getPrivateBlockArgs(), op.getPrivateVars());
        forwardArgs(blockArgIface.getReductionBlockArgs(),
                    op.getReductionVars());
        op.emitWarning() << "simd information on composite construct discarded";
        return success();
      })
      .Default([&](Operation *op) {
        return op->emitError() << "cannot ignore nested wrapper";
      });
}

/// Helper function to call \c convertIgnoredWrapper() for all wrappers of the
/// given \c loopOp nested inside of \c parentOp. This has the effect of mapping
/// entry block arguments defined by these operations to outside values.
///
/// It must be called after block arguments of \c parentOp have already been
/// mapped themselves.
static LogicalResult
convertIgnoredWrappers(omp::LoopNestOp loopOp,
                       omp::LoopWrapperInterface parentOp,
                       LLVM::ModuleTranslation &moduleTranslation) {
  SmallVector<omp::LoopWrapperInterface> wrappers;
  loopOp.gatherWrappers(wrappers);

  // Process wrappers nested inside of `parentOp` from outermost to innermost.
  // `gatherWrappers` collects innermost-first, hence the reverse iteration
  // starting just past `parentOp`.
  for (auto it =
           std::next(std::find(wrappers.rbegin(), wrappers.rend(), parentOp));
       it != wrappers.rend(); ++it) {
    if (failed(convertIgnoredWrapper(*it, moduleTranslation)))
      return failure();
  }

  return success();
}

/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
578 static LogicalResult 579 convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, 580 LLVM::ModuleTranslation &moduleTranslation) { 581 auto maskedOp = cast<omp::MaskedOp>(opInst); 582 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; 583 584 if (failed(checkImplementationStatus(opInst))) 585 return failure(); 586 587 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) { 588 // MaskedOp has only one region associated with it. 589 auto ®ion = maskedOp.getRegion(); 590 builder.restoreIP(codeGenIP); 591 return convertOmpOpRegions(region, "omp.masked.region", builder, 592 moduleTranslation) 593 .takeError(); 594 }; 595 596 // TODO: Perform finalization actions for variables. This has to be 597 // called for variables which have destructors/finalizers. 598 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); }; 599 600 llvm::Value *filterVal = nullptr; 601 if (auto filterVar = maskedOp.getFilteredThreadId()) { 602 filterVal = moduleTranslation.lookupValue(filterVar); 603 } else { 604 llvm::LLVMContext &llvmContext = builder.getContext(); 605 filterVal = 606 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/0); 607 } 608 assert(filterVal != nullptr); 609 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); 610 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = 611 moduleTranslation.getOpenMPBuilder()->createMasked(ompLoc, bodyGenCB, 612 finiCB, filterVal); 613 614 if (failed(handleError(afterIP, opInst))) 615 return failure(); 616 617 builder.restoreIP(*afterIP); 618 return success(); 619 } 620 621 /// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder. 
static LogicalResult
convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto masterOp = cast<omp::MasterOp>(opInst);

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    // MasterOp has only one region associated with it.
    auto &region = masterOp.getRegion();
    builder.restoreIP(codeGenIP);
    return convertOmpOpRegions(region, "omp.master.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createMaster(ompLoc, bodyGenCB,
                                                         finiCB);

  if (failed(handleError(afterIP, opInst)))
    return failure();

  // Resume emission after the master construct.
  builder.restoreIP(*afterIP);
  return success();
}

/// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto criticalOp = cast<omp::CriticalOp>(opInst);

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    // CriticalOp has only one region associated with it.
    auto &region = cast<omp::CriticalOp>(opInst).getRegion();
    builder.restoreIP(codeGenIP);
    return convertOmpOpRegions(region, "omp.critical.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
  llvm::Constant *hint = nullptr;

  // If it has a name, it probably has a hint too: named critical sections
  // reference an omp.critical.declare that carries the hint value.
  if (criticalOp.getNameAttr()) {
    // The verifiers in OpenMP Dialect guarantee that all the pointers are
    // non-null
    auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
    auto criticalDeclareOp =
        SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
                                                                     symbolRef);
    hint =
        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
                               static_cast<int>(criticalDeclareOp.getHint()));
  }
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createCritical(
          ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}

/// Populates `privatizations` with privatization declarations used for the
/// given op.
template <class OP>
static void collectPrivatizationDecls(
    OP op, SmallVectorImpl<omp::PrivateClauseOp> &privatizations) {
  std::optional<ArrayAttr> attr = op.getPrivateSyms();
  if (!attr)
    return;

  privatizations.reserve(privatizations.size() + attr->size());
  for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
    privatizations.push_back(findPrivatizer(op, symbolRef));
  }
}

/// Populates `reductions` with reduction declarations used in the given op.
template <typename T>
static void
collectReductionDecls(T op,
                      SmallVectorImpl<omp::DeclareReductionOp> &reductions) {
  std::optional<ArrayAttr> attr = op.getReductionSyms();
  if (!attr)
    return;

  reductions.reserve(reductions.size() + op.getNumReductionVars());
  for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
    reductions.push_back(
        SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
            op, symbolRef));
  }
}

/// Translates the blocks contained in the given region and appends them at
/// the current insertion point of `builder`. The operations of the entry block
/// are appended to the current insertion block. If set, `continuationBlockArgs`
/// is populated with translated values that correspond to the values
/// omp.yield'ed from the region.
static LogicalResult inlineConvertOmpRegions(
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
  if (region.empty())
    return success();

  // Special case for single-block regions that don't create additional blocks:
  // insert operations without creating additional blocks.
  if (llvm::hasSingleElement(region)) {
    // Temporarily detach the insertion block's terminator (if any) so that the
    // region's operations can be appended to the block; it is re-inserted at
    // the end of the block below.
    llvm::Instruction *potentialTerminator =
        builder.GetInsertBlock()->empty() ? nullptr
                                          : &builder.GetInsertBlock()->back();

    if (potentialTerminator && potentialTerminator->isTerminator())
      potentialTerminator->removeFromParent();
    moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());

    if (failed(moduleTranslation.convertBlock(
            region.front(), /*ignoreArguments=*/true, builder)))
      return failure();

    // The continuation arguments are simply the translated terminator operands.
    if (continuationBlockArgs)
      llvm::append_range(
          *continuationBlockArgs,
          moduleTranslation.lookupValues(region.front().back().getOperands()));

    // Drop the mapping that is no longer necessary so that the same region can
    // be processed multiple times.
    moduleTranslation.forgetMapping(region);

    if (potentialTerminator && potentialTerminator->isTerminator()) {
      llvm::BasicBlock *block = builder.GetInsertBlock();
      if (block->empty()) {
        // this can happen for really simple reduction init regions e.g.
        // %0 = llvm.mlir.constant(0 : i32) : i32
        // omp.yield(%0 : i32)
        // because the llvm.mlir.constant (MLIR op) isn't converted into any
        // llvm op
        potentialTerminator->insertInto(block, block->begin());
      } else {
        potentialTerminator->insertAfter(&block->back());
      }
    }

    return success();
  }

  // General case: convert the region as separate blocks joined by a
  // continuation block, then continue building there.
  SmallVector<llvm::PHINode *> phis;
  llvm::Expected<llvm::BasicBlock *> continuationBlock =
      convertOmpOpRegions(region, blockName, builder, moduleTranslation, &phis);

  if (failed(handleError(continuationBlock, *region.getParentOp())))
    return failure();

  if (continuationBlockArgs)
    llvm::append_range(*continuationBlockArgs, phis);
  builder.SetInsertPoint(*continuationBlock,
                         (*continuationBlock)->getFirstInsertionPt());
  return success();
}

namespace {
/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
/// store lambdas with capture.
809 using OwningReductionGen = 810 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy( 811 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *, 812 llvm::Value *&)>; 813 using OwningAtomicReductionGen = 814 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy( 815 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *, 816 llvm::Value *)>; 817 } // namespace 818 819 /// Create an OpenMPIRBuilder-compatible reduction generator for the given 820 /// reduction declaration. The generator uses `builder` but ignores its 821 /// insertion point. 822 static OwningReductionGen 823 makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, 824 LLVM::ModuleTranslation &moduleTranslation) { 825 // The lambda is mutable because we need access to non-const methods of decl 826 // (which aren't actually mutating it), and we must capture decl by-value to 827 // avoid the dangling reference after the parent function returns. 828 OwningReductionGen gen = 829 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, 830 llvm::Value *lhs, llvm::Value *rhs, 831 llvm::Value *&result) mutable 832 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy { 833 moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs); 834 moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs); 835 builder.restoreIP(insertPoint); 836 SmallVector<llvm::Value *> phis; 837 if (failed(inlineConvertOmpRegions(decl.getReductionRegion(), 838 "omp.reduction.nonatomic.body", builder, 839 moduleTranslation, &phis))) 840 return llvm::createStringError( 841 "failed to inline `combiner` region of `omp.declare_reduction`"); 842 assert(phis.size() == 1); 843 result = phis[0]; 844 return builder.saveIP(); 845 }; 846 return gen; 847 } 848 849 /// Create an OpenMPIRBuilder-compatible atomic reduction generator for the 850 /// given reduction declaration. The generator uses `builder` but ignores its 851 /// insertion point. 
Returns null if there is no atomic region available in the 852 /// reduction declaration. 853 static OwningAtomicReductionGen 854 makeAtomicReductionGen(omp::DeclareReductionOp decl, 855 llvm::IRBuilderBase &builder, 856 LLVM::ModuleTranslation &moduleTranslation) { 857 if (decl.getAtomicReductionRegion().empty()) 858 return OwningAtomicReductionGen(); 859 860 // The lambda is mutable because we need access to non-const methods of decl 861 // (which aren't actually mutating it), and we must capture decl by-value to 862 // avoid the dangling reference after the parent function returns. 863 OwningAtomicReductionGen atomicGen = 864 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *, 865 llvm::Value *lhs, llvm::Value *rhs) mutable 866 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy { 867 moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs); 868 moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs); 869 builder.restoreIP(insertPoint); 870 SmallVector<llvm::Value *> phis; 871 if (failed(inlineConvertOmpRegions(decl.getAtomicReductionRegion(), 872 "omp.reduction.atomic.body", builder, 873 moduleTranslation, &phis))) 874 return llvm::createStringError( 875 "failed to inline `atomic` region of `omp.declare_reduction`"); 876 assert(phis.empty()); 877 return builder.saveIP(); 878 }; 879 return atomicGen; 880 } 881 882 /// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder. 
static LogicalResult
convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation) {
  auto orderedOp = cast<omp::OrderedOp>(opInst);

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // NOTE(review): the doacross depend type / num-loops attributes are
  // dereferenced unconditionally — presumably the verifier guarantees they are
  // present on a standalone `omp.ordered`; confirm against the op definition.
  omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
  bool isDependSource = dependType == omp::ClauseDepend::dependsource;
  unsigned numLoops = *orderedOp.getDoacrossNumLoops();
  SmallVector<llvm::Value *> vecValues =
      moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());

  // The depend vars form consecutive groups of `numLoops` values; emit one
  // ordered-depend call per group.
  size_t indexVecValues = 0;
  while (indexVecValues < vecValues.size()) {
    SmallVector<llvm::Value *> storeValues;
    storeValues.reserve(numLoops);
    for (unsigned i = 0; i < numLoops; i++) {
      storeValues.push_back(vecValues[indexVecValues]);
      indexVecValues++;
    }
    llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
        findAllocaInsertPoint(builder, moduleTranslation);
    llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
    builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
        ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
  }
  return success();
}

/// Converts an OpenMP 'ordered_region' operation into LLVM IR using
/// OpenMPIRBuilder.
static LogicalResult
convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // Body callback: inline the single region of the op at the code-gen
  // insertion point supplied by the OpenMPIRBuilder.
  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    // OrderedOp has only one region associated with it.
    auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
    builder.restoreIP(codeGenIP);
    return convertOmpOpRegions(region, "omp.ordered.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
          ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}

namespace {
/// Contains the arguments for an LLVM store operation
struct DeferredStore {
  DeferredStore(llvm::Value *value, llvm::Value *address)
      : value(value), address(address) {}

  llvm::Value *value;   // value to store
  llvm::Value *address; // destination address
};
} // namespace

/// Allocate space for privatized reduction variables.
/// `deferredStores` contains information to create store operations which needs
/// to be inserted after all allocas
template <typename T>
static LogicalResult
allocReductionVars(T loop, ArrayRef<BlockArgument> reductionArgs,
                   llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation,
                   const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
                   SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
                   SmallVectorImpl<llvm::Value *> &privateReductionVariables,
                   DenseMap<Value, llvm::Value *> &reductionVariableMap,
                   SmallVectorImpl<DeferredStore> &deferredStores,
                   llvm::ArrayRef<bool> isByRefs) {
  // Restore the caller's insertion point on exit; allocas are emitted just
  // before the alloca block's terminator.
  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());

  // delay creating stores until after all allocas
  deferredStores.reserve(loop.getNumReductionVars());

  for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
    Region &allocRegion = reductionDecls[i].getAllocRegion();
    if (isByRefs[i]) {
      // By-ref without an alloc region: handled later in initReductionVars.
      if (allocRegion.empty())
        continue;

      SmallVector<llvm::Value *, 1> phis;
      if (failed(inlineConvertOmpRegions(allocRegion, "omp.reduction.alloc",
                                         builder, moduleTranslation, &phis)))
        return loop.emitError(
            "failed to inline `alloc` region of `omp.declare_reduction`");

      assert(phis.size() == 1 && "expected one allocation to be yielded");
      // Inlining the alloc region may have moved the insertion point; put it
      // back before the alloca block's terminator.
      builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());

      // Allocate reduction variable (which is a pointer to the real reduction
      // variable allocated in the inlined region)
      llvm::Value *var = builder.CreateAlloca(
          moduleTranslation.convertType(reductionDecls[i].getType()));
      // The store of the yielded pointer is deferred until after all allocas.
      deferredStores.emplace_back(phis[0], var);

      privateReductionVariables[i] = var;
      moduleTranslation.mapValue(reductionArgs[i], phis[0]);
      reductionVariableMap.try_emplace(loop.getReductionVars()[i], phis[0]);
    } else {
      assert(allocRegion.empty() &&
             "allocaction is implicit for by-val reduction");
      llvm::Value *var = builder.CreateAlloca(
          moduleTranslation.convertType(reductionDecls[i].getType()));
      moduleTranslation.mapValue(reductionArgs[i], var);
      privateReductionVariables[i] = var;
      reductionVariableMap.try_emplace(loop.getReductionVars()[i], var);
    }
  }

  return success();
}

/// Map input arguments to reduction initialization region
template <typename T>
static void
mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation,
                      SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
                      DenseMap<Value, llvm::Value *> &reductionVariableMap,
                      unsigned i) {
  // map input argument to the initialization region
  mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
  Region &initializerRegion = reduction.getInitializerRegion();
  Block &entry = initializerRegion.front();

  mlir::Value mlirSource = loop.getReductionVars()[i];
  llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
  assert(llvmSource && "lookup reduction var");
  moduleTranslation.mapValue(reduction.getInitializerMoldArg(), llvmSource);

  // A second entry argument, when present, receives the allocation produced
  // by the alloc region (recorded in reductionVariableMap).
  if (entry.getNumArguments() > 1) {
    llvm::Value *allocation =
        reductionVariableMap.lookup(loop.getReductionVars()[i]);
    moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
  }
}

/// Inline reductions' `init` regions. This functions assumes that the
/// `builder`'s insertion point is where the user wants the `init` regions to be
/// inlined; i.e. it does not try to find a proper insertion location for the
/// `init` regions. It also leaves the `builder's insertions point in a state
/// where the user can continue the code-gen directly afterwards.
1048 template <typename OP> 1049 static LogicalResult 1050 initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs, 1051 llvm::IRBuilderBase &builder, 1052 LLVM::ModuleTranslation &moduleTranslation, 1053 llvm::BasicBlock *latestAllocaBlock, 1054 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, 1055 SmallVectorImpl<llvm::Value *> &privateReductionVariables, 1056 DenseMap<Value, llvm::Value *> &reductionVariableMap, 1057 llvm::ArrayRef<bool> isByRef, 1058 SmallVectorImpl<DeferredStore> &deferredStores) { 1059 if (op.getNumReductionVars() == 0) 1060 return success(); 1061 1062 llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init"); 1063 auto allocaIP = llvm::IRBuilderBase::InsertPoint( 1064 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator()); 1065 builder.restoreIP(allocaIP); 1066 SmallVector<llvm::Value *> byRefVars(op.getNumReductionVars()); 1067 1068 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) { 1069 if (isByRef[i]) { 1070 if (!reductionDecls[i].getAllocRegion().empty()) 1071 continue; 1072 1073 // TODO: remove after all users of by-ref are updated to use the alloc 1074 // region: Allocate reduction variable (which is a pointer to the real 1075 // reduciton variable allocated in the inlined region) 1076 byRefVars[i] = builder.CreateAlloca( 1077 moduleTranslation.convertType(reductionDecls[i].getType())); 1078 } 1079 } 1080 1081 if (initBlock->empty() || initBlock->getTerminator() == nullptr) 1082 builder.SetInsertPoint(initBlock); 1083 else 1084 builder.SetInsertPoint(initBlock->getTerminator()); 1085 1086 // store result of the alloc region to the allocated pointer to the real 1087 // reduction variable 1088 for (auto [data, addr] : deferredStores) 1089 builder.CreateStore(data, addr); 1090 1091 // Before the loop, store the initial values of reductions into reduction 1092 // variables. Although this could be done after allocas, we don't want to mess 1093 // up with the alloca insertion point. 
1094 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) { 1095 SmallVector<llvm::Value *, 1> phis; 1096 1097 // map block argument to initializer region 1098 mapInitializationArgs(op, moduleTranslation, reductionDecls, 1099 reductionVariableMap, i); 1100 1101 if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(), 1102 "omp.reduction.neutral", builder, 1103 moduleTranslation, &phis))) 1104 return failure(); 1105 1106 assert(phis.size() == 1 && "expected one value to be yielded from the " 1107 "reduction neutral element declaration region"); 1108 1109 if (builder.GetInsertBlock()->empty() || 1110 builder.GetInsertBlock()->getTerminator() == nullptr) 1111 builder.SetInsertPoint(builder.GetInsertBlock()); 1112 else 1113 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator()); 1114 1115 if (isByRef[i]) { 1116 if (!reductionDecls[i].getAllocRegion().empty()) 1117 // done in allocReductionVars 1118 continue; 1119 1120 // TODO: this path can be removed once all users of by-ref are updated to 1121 // use an alloc region 1122 1123 // Store the result of the inlined region to the allocated reduction var 1124 // ptr 1125 builder.CreateStore(phis[0], byRefVars[i]); 1126 1127 privateReductionVariables[i] = byRefVars[i]; 1128 moduleTranslation.mapValue(reductionArgs[i], phis[0]); 1129 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]); 1130 } else { 1131 // for by-ref case the store is inside of the reduction region 1132 builder.CreateStore(phis[0], privateReductionVariables[i]); 1133 // the rest was handled in allocByValReductionVars 1134 } 1135 1136 // forget the mapping for the initializer region because we might need a 1137 // different mapping if this reduction declaration is re-used for a 1138 // different variable 1139 moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion()); 1140 } 1141 1142 return success(); 1143 } 1144 1145 /// Collect reduction info 1146 template <typename T> 1147 static void 
collectReductionInfo( 1148 T loop, llvm::IRBuilderBase &builder, 1149 LLVM::ModuleTranslation &moduleTranslation, 1150 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, 1151 SmallVectorImpl<OwningReductionGen> &owningReductionGens, 1152 SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens, 1153 const ArrayRef<llvm::Value *> privateReductionVariables, 1154 SmallVectorImpl<llvm::OpenMPIRBuilder::ReductionInfo> &reductionInfos) { 1155 unsigned numReductions = loop.getNumReductionVars(); 1156 1157 for (unsigned i = 0; i < numReductions; ++i) { 1158 owningReductionGens.push_back( 1159 makeReductionGen(reductionDecls[i], builder, moduleTranslation)); 1160 owningAtomicReductionGens.push_back( 1161 makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation)); 1162 } 1163 1164 // Collect the reduction information. 1165 reductionInfos.reserve(numReductions); 1166 for (unsigned i = 0; i < numReductions; ++i) { 1167 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr; 1168 if (owningAtomicReductionGens[i]) 1169 atomicGen = owningAtomicReductionGens[i]; 1170 llvm::Value *variable = 1171 moduleTranslation.lookupValue(loop.getReductionVars()[i]); 1172 reductionInfos.push_back( 1173 {moduleTranslation.convertType(reductionDecls[i].getType()), variable, 1174 privateReductionVariables[i], 1175 /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar, 1176 owningReductionGens[i], 1177 /*ReductionGenClang=*/nullptr, atomicGen}); 1178 } 1179 } 1180 1181 /// handling of DeclareReductionOp's cleanup region 1182 static LogicalResult 1183 inlineOmpRegionCleanup(llvm::SmallVectorImpl<Region *> &cleanupRegions, 1184 llvm::ArrayRef<llvm::Value *> privateVariables, 1185 LLVM::ModuleTranslation &moduleTranslation, 1186 llvm::IRBuilderBase &builder, StringRef regionName, 1187 bool shouldLoadCleanupRegionArg = true) { 1188 for (auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) { 1189 if (cleanupRegion->empty()) 1190 continue; 1191 1192 // 
map the argument to the cleanup region 1193 Block &entry = cleanupRegion->front(); 1194 1195 llvm::Instruction *potentialTerminator = 1196 builder.GetInsertBlock()->empty() ? nullptr 1197 : &builder.GetInsertBlock()->back(); 1198 if (potentialTerminator && potentialTerminator->isTerminator()) 1199 builder.SetInsertPoint(potentialTerminator); 1200 llvm::Value *privateVarValue = 1201 shouldLoadCleanupRegionArg 1202 ? builder.CreateLoad( 1203 moduleTranslation.convertType(entry.getArgument(0).getType()), 1204 privateVariables[i]) 1205 : privateVariables[i]; 1206 1207 moduleTranslation.mapValue(entry.getArgument(0), privateVarValue); 1208 1209 if (failed(inlineConvertOmpRegions(*cleanupRegion, regionName, builder, 1210 moduleTranslation))) 1211 return failure(); 1212 1213 // clear block argument mapping in case it needs to be re-created with a 1214 // different source for another use of the same reduction decl 1215 moduleTranslation.forgetMapping(*cleanupRegion); 1216 } 1217 return success(); 1218 } 1219 1220 // TODO: not used by ParallelOp 1221 template <class OP> 1222 static LogicalResult createReductionsAndCleanup( 1223 OP op, llvm::IRBuilderBase &builder, 1224 LLVM::ModuleTranslation &moduleTranslation, 1225 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, 1226 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, 1227 ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef) { 1228 // Process the reductions if required. 1229 if (op.getNumReductionVars() == 0) 1230 return success(); 1231 1232 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); 1233 1234 // Create the reduction generators. We need to own them here because 1235 // ReductionInfo only accepts references to the generators. 
1236 SmallVector<OwningReductionGen> owningReductionGens; 1237 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens; 1238 SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos; 1239 collectReductionInfo(op, builder, moduleTranslation, reductionDecls, 1240 owningReductionGens, owningAtomicReductionGens, 1241 privateReductionVariables, reductionInfos); 1242 1243 // The call to createReductions below expects the block to have a 1244 // terminator. Create an unreachable instruction to serve as terminator 1245 // and remove it later. 1246 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable(); 1247 builder.SetInsertPoint(tempTerminator); 1248 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint = 1249 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos, 1250 isByRef, op.getNowait()); 1251 1252 if (failed(handleError(contInsertPoint, *op))) 1253 return failure(); 1254 1255 if (!contInsertPoint->getBlock()) 1256 return op->emitOpError() << "failed to convert reductions"; 1257 1258 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = 1259 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for); 1260 1261 if (failed(handleError(afterIP, *op))) 1262 return failure(); 1263 1264 tempTerminator->eraseFromParent(); 1265 builder.restoreIP(*afterIP); 1266 1267 // after the construct, deallocate private reduction variables 1268 SmallVector<Region *> reductionRegions; 1269 llvm::transform(reductionDecls, std::back_inserter(reductionRegions), 1270 [](omp::DeclareReductionOp reductionDecl) { 1271 return &reductionDecl.getCleanupRegion(); 1272 }); 1273 return inlineOmpRegionCleanup(reductionRegions, privateReductionVariables, 1274 moduleTranslation, builder, 1275 "omp.reduction.cleanup"); 1276 return success(); 1277 } 1278 1279 static ArrayRef<bool> getIsByRef(std::optional<ArrayRef<bool>> attr) { 1280 if (!attr) 1281 return {}; 1282 return *attr; 1283 } 1284 1285 // TODO: not used by omp.parallel 1286 template 
<typename OP> 1287 static LogicalResult allocAndInitializeReductionVars( 1288 OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder, 1289 LLVM::ModuleTranslation &moduleTranslation, 1290 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, 1291 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, 1292 SmallVectorImpl<llvm::Value *> &privateReductionVariables, 1293 DenseMap<Value, llvm::Value *> &reductionVariableMap, 1294 llvm::ArrayRef<bool> isByRef) { 1295 if (op.getNumReductionVars() == 0) 1296 return success(); 1297 1298 SmallVector<DeferredStore> deferredStores; 1299 1300 if (failed(allocReductionVars(op, reductionArgs, builder, moduleTranslation, 1301 allocaIP, reductionDecls, 1302 privateReductionVariables, reductionVariableMap, 1303 deferredStores, isByRef))) 1304 return failure(); 1305 1306 return initReductionVars(op, reductionArgs, builder, moduleTranslation, 1307 allocaIP.getBlock(), reductionDecls, 1308 privateReductionVariables, reductionVariableMap, 1309 isByRef, deferredStores); 1310 } 1311 1312 /// Return the llvm::Value * corresponding to the `privateVar` that 1313 /// is being privatized. It isn't always as simple as looking up 1314 /// moduleTranslation with privateVar. For instance, in case of 1315 /// an allocatable, the descriptor for the allocatable is privatized. 1316 /// This descriptor is mapped using an MapInfoOp. So, this function 1317 /// will return a pointer to the llvm::Value corresponding to the 1318 /// block argument for the mapped descriptor. 
1319 static llvm::Value * 1320 findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, 1321 LLVM::ModuleTranslation &moduleTranslation, 1322 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) { 1323 if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar)) 1324 return moduleTranslation.lookupValue(privateVar); 1325 1326 Value blockArg = (*mappedPrivateVars)[privateVar]; 1327 Type privVarType = privateVar.getType(); 1328 Type blockArgType = blockArg.getType(); 1329 assert(isa<LLVM::LLVMPointerType>(blockArgType) && 1330 "A block argument corresponding to a mapped var should have " 1331 "!llvm.ptr type"); 1332 1333 if (privVarType == blockArgType) 1334 return moduleTranslation.lookupValue(blockArg); 1335 1336 // This typically happens when the privatized type is lowered from 1337 // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the 1338 // struct/pair is passed by value. But, mapped values are passed only as 1339 // pointers, so before we privatize, we must load the pointer. 1340 if (!isa<LLVM::LLVMPointerType>(privVarType)) 1341 return builder.CreateLoad(moduleTranslation.convertType(privVarType), 1342 moduleTranslation.lookupValue(blockArg)); 1343 1344 return moduleTranslation.lookupValue(privateVar); 1345 } 1346 1347 /// Allocate delayed private variables. Returns the basic block which comes 1348 /// after all of these allocations. llvm::Value * for each of these private 1349 /// variables are populated in llvmPrivateVars. 
static llvm::Expected<llvm::BasicBlock *>
allocatePrivateVars(llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    MutableArrayRef<BlockArgument> privateBlockArgs,
                    MutableArrayRef<omp::PrivateClauseOp> privateDecls,
                    MutableArrayRef<mlir::Value> mlirPrivateVars,
                    llvm::SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
                    const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
                    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  // Allocate private vars
  llvm::BranchInst *allocaTerminator =
      llvm::cast<llvm::BranchInst>(allocaIP.getBlock()->getTerminator());
  // Split so that everything after the allocas lives in its own block.
  splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
                                               allocaTerminator->getIterator()),
          true, "omp.region.after_alloca");

  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  // Update the allocaTerminator in case the alloca block was split above.
  allocaTerminator =
      llvm::cast<llvm::BranchInst>(allocaIP.getBlock()->getTerminator());
  builder.SetInsertPoint(allocaTerminator);
  assert(allocaTerminator->getNumSuccessors() == 1 &&
         "This is an unconditional branch created by OpenMPIRBuilder");

  llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);

  // FIXME: Some of the allocation regions do more than just allocating.
  // They read from their block argument (amongst other non-alloca things).
  // When OpenMPIRBuilder outlines the parallel region into a different
  // function it places the loads for live in-values (such as these block
  // arguments) at the end of the entry block (because the entry block is
  // assumed to contain only allocas). Therefore, if we put these complicated
  // alloc blocks in the entry block, these will not dominate the availability
  // of the live-in values they are using. Fix this by adding a latealloc
  // block after the entry block to put these in (this also helps to avoid
  // mixing non-alloca code with allocas).
  // Alloc regions which do not use the block argument can still be placed in
  // the entry block (therefore keeping the allocas together).
  llvm::BasicBlock *privAllocBlock = nullptr;
  if (!privateBlockArgs.empty())
    privAllocBlock = splitBB(builder, true, "omp.private.latealloc");
  for (auto [privDecl, mlirPrivVar, blockArg] :
       llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs)) {
    Region &allocRegion = privDecl.getAllocRegion();

    // map allocation region block argument
    llvm::Value *nonPrivateVar = findAssociatedValue(
        mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
    assert(nonPrivateVar);
    moduleTranslation.mapValue(privDecl.getAllocMoldArg(), nonPrivateVar);

    // in-place convert the private allocation region
    SmallVector<llvm::Value *, 1> phis;
    if (privDecl.getAllocMoldArg().getUses().empty()) {
      // TODO this should use
      // allocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca() so it goes before
      // the code for fetching the thread id. Not doing this for now to avoid
      // test churn.
      builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
    } else {
      // The alloc region reads its mold argument, so it goes into the
      // latealloc block (see FIXME above).
      builder.SetInsertPoint(privAllocBlock->getTerminator());
    }

    if (failed(inlineConvertOmpRegions(allocRegion, "omp.private.alloc",
                                       builder, moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `alloc` region of `omp.private`");

    assert(phis.size() == 1 && "expected one allocation to be yielded");

    moduleTranslation.mapValue(blockArg, phis[0]);
    llvmPrivateVars.push_back(phis[0]);

    // clear alloc region block argument mapping in case it needs to be
    // re-created with a different source for another use of the same
    // reduction decl
    moduleTranslation.forgetMapping(allocRegion);
  }
  return afterAllocas;
}

/// Inlines the `copy` regions of `firstprivate` delayed privatizers, copying
/// the original values into the freshly allocated private copies. No-op when
/// no declaration is firstprivate.
static LogicalResult
initFirstPrivateVars(llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation,
                     SmallVectorImpl<mlir::Value> &mlirPrivateVars,
                     SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
                     SmallVectorImpl<omp::PrivateClauseOp> &privateDecls,
                     llvm::BasicBlock *afterAllocas) {
  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  // Apply copy region for firstprivate.
  bool needsFirstprivate =
      llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
        return privOp.getDataSharingType() ==
               omp::DataSharingClauseType::FirstPrivate;
      });

  if (!needsFirstprivate)
    return success();

  assert(afterAllocas->getSinglePredecessor());

  // Find the end of the allocation blocks
  builder.SetInsertPoint(afterAllocas->getSinglePredecessor()->getTerminator());
  llvm::BasicBlock *copyBlock =
      splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
  builder.SetInsertPoint(copyBlock->getFirstNonPHIOrDbgOrAlloca());

  for (auto [decl, mlirVar, llvmVar] :
       llvm::zip_equal(privateDecls, mlirPrivateVars, llvmPrivateVars)) {
    if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
      continue;

    // copyRegion implements `lhs = rhs`
    Region &copyRegion = decl.getCopyRegion();

    // map copyRegion rhs arg
    llvm::Value *nonPrivateVar = moduleTranslation.lookupValue(mlirVar);
    assert(nonPrivateVar);
    moduleTranslation.mapValue(decl.getCopyMoldArg(), nonPrivateVar);

    // map copyRegion lhs arg
    moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);

    // in-place convert copy region
    builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
    if (failed(inlineConvertOmpRegions(copyRegion, "omp.private.copy", builder,
                                       moduleTranslation)))
      return decl.emitError("failed to inline `copy` region of `omp.private`");

    // ignore unused value yielded from copy region

    // clear copy region block argument mapping in case it needs to be
    // re-created with different sources for reuse of the same reduction
    // decl
    moduleTranslation.forgetMapping(copyRegion);
  }

  return success();
}

/// Inlines the `dealloc` regions of the given delayed privatizers for the
/// corresponding private variables.
static LogicalResult
cleanupPrivateVars(llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation, Location loc,
                   SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
                   SmallVectorImpl<omp::PrivateClauseOp> &privateDecls) {
  // private variable deallocation
  SmallVector<Region *> privateCleanupRegions;
  llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
                  [](omp::PrivateClauseOp privatizer) {
                    return &privatizer.getDeallocRegion();
                  });

  if (failed(inlineOmpRegionCleanup(
          privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
          "omp.private.dealloc", /*shouldLoadCleanupRegionArg=*/false)))
    return mlir::emitError(loc, "failed to inline `dealloc` region of an "
                                "`omp.private` op in");

  return success();
}

/// Converts an OpenMP 'sections' construct (and its nested 'section' regions)
/// into LLVM IR using OpenMPIRBuilder, including reduction handling.
static LogicalResult
convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  using StorableBodyGenCallbackTy =
      llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;

  auto sectionsOp = cast<omp::SectionsOp>(opInst);

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  llvm::ArrayRef<bool> isByRef = getIsByRef(sectionsOp.getReductionByref());
  assert(isByRef.size() == sectionsOp.getNumReductionVars());

  SmallVector<omp::DeclareReductionOp> reductionDecls;
  collectReductionDecls(sectionsOp, reductionDecls);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  SmallVector<llvm::Value *> privateReductionVariables(
      sectionsOp.getNumReductionVars());
  DenseMap<Value, llvm::Value *> reductionVariableMap;

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();

  if (failed(allocAndInitializeReductionVars(
          sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
          reductionDecls, privateReductionVariables, reductionVariableMap,
          isByRef)))
    return failure();

  // Store the mapping between reduction variables and their private copies on
  // ModuleTranslation stack. It can be then recovered when translating
  // omp.reduce operations in a separate call.
  LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
      moduleTranslation, reductionVariableMap);

  SmallVector<StorableBodyGenCallbackTy> sectionCBs;

  for (Operation &op : *sectionsOp.getRegion().begin()) {
    auto sectionOp = dyn_cast<omp::SectionOp>(op);
    if (!sectionOp) // omp.terminator
      continue;

    Region &region = sectionOp.getRegion();
    auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
                         InsertPointTy allocaIP, InsertPointTy codeGenIP) {
      builder.restoreIP(codeGenIP);

      // map the omp.section reduction block argument to the omp.sections block
      // arguments
      // TODO: this assumes that the only block arguments are reduction
      // variables
      assert(region.getNumArguments() ==
             sectionsOp.getRegion().getNumArguments());
      for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
               sectionsOp.getRegion().getArguments(), region.getArguments())) {
        llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
        assert(llvmVal);
        moduleTranslation.mapValue(sectionArg, llvmVal);
      }

      return convertOmpOpRegions(region, "omp.section.region", builder,
                                 moduleTranslation)
          .takeError();
    };
    sectionCBs.push_back(sectionCB);
  }

  // No sections within omp.sections operation - skip generation. This situation
  // is only possible if there is only a terminator operation inside the
  // sections operation
  if (sectionCBs.empty())
    return success();

  assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));

  // TODO: Perform appropriate actions according to the data-sharing
  // attribute (shared, private, firstprivate, ...) of variables.
  // Currently defaults to shared.
  auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
                    llvm::Value &vPtr, llvm::Value *&replacementValue)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    replacementValue = &vPtr;
    return codeGenIP;
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createSections(
          ompLoc, allocaIP, sectionCBs, privCB, finiCB, false,
          sectionsOp.getNowait());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  // Process the reductions if required.
  return createReductionsAndCleanup(sectionsOp, builder, moduleTranslation,
                                    allocaIP, reductionDecls,
                                    privateReductionVariables, isByRef);
}

/// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  if (failed(checkImplementationStatus(*singleOp)))
    return failure();

  // Body callback: emit the single region's contents at the given codegen IP.
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
    builder.restoreIP(codegenIP);
    return convertOmpOpRegions(singleOp.getRegion(), "omp.single.region",
                               builder, moduleTranslation)
        .takeError();
  };
  // No per-variable finalization needed for single.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  // Handle copyprivate
  // NOTE(review): the loop below dereferences `*cpFuncs` for every copyprivate
  // variable — this assumes the syms attribute is present and parallel to
  // cpVars whenever cpVars is non-empty (presumably an op verifier invariant;
  // confirm).
  Operation::operand_range cpVars = singleOp.getCopyprivateVars();
  std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
  llvm::SmallVector<llvm::Value *> llvmCPVars;
  llvm::SmallVector<llvm::Function *> llvmCPFuncs;
  for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
    llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
    auto llvmFuncOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
        singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
    llvmCPFuncs.push_back(
        moduleTranslation.lookupFunction(llvmFuncOp.getName()));
  }

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createSingle(
          ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
          llvmCPFuncs);

  if (failed(handleError(afterIP, *singleOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}

// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
static LogicalResult
convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
                LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  if (failed(checkImplementationStatus(*op)))
    return failure();

  // Body callback: push the alloca insertion point on the ModuleTranslation
  // stack so nested regions can find it, then emit the teams region.
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
        moduleTranslation, allocaIP);
    builder.restoreIP(codegenIP);
    return convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // Translate the optional clause operands; each stays null when the clause
  // is absent, which createTeams interprets as "not specified".
  llvm::Value *numTeamsLower = nullptr;
  if (Value numTeamsLowerVar = op.getNumTeamsLower())
    numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);

  llvm::Value *numTeamsUpper = nullptr;
  if (Value numTeamsUpperVar = op.getNumTeamsUpper())
    numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);

  llvm::Value *threadLimit = nullptr;
  if (Value threadLimitVar = op.getThreadLimit())
    threadLimit = moduleTranslation.lookupValue(threadLimitVar);

  llvm::Value *ifExpr = nullptr;
  if (Value ifVar = op.getIfExpr())
    ifExpr = moduleTranslation.lookupValue(ifVar);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTeams(
          ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);

  if (failed(handleError(afterIP, *op)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}

/// Populates \p dds with one OpenMPIRBuilder::DependData entry per depend
/// variable, translating the MLIR dependence kind attribute to the runtime's
/// RTLDependenceKindTy.
///
/// NOTE(review): when \p dependVars is non-empty, \p dependKinds is
/// dereferenced unconditionally — assumes the kinds array is always present
/// and parallel to the variables (presumably an op verifier invariant;
/// confirm).
static void
buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
                LLVM::ModuleTranslation &moduleTranslation,
                SmallVectorImpl<llvm::OpenMPIRBuilder::DependData> &dds) {
  if (dependVars.empty())
    return;
  for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
    llvm::omp::RTLDependenceKindTy type;
    switch (
        cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
    case mlir::omp::ClauseTaskDepend::taskdependin:
      type = llvm::omp::RTLDependenceKindTy::DepIn;
      break;
    // The OpenMP runtime requires that the codegen for 'depend' clause for
    // 'out' dependency kind must be the same as codegen for 'depend' clause
    // with 'inout' dependency.
    case mlir::omp::ClauseTaskDepend::taskdependout:
    case mlir::omp::ClauseTaskDepend::taskdependinout:
      type = llvm::omp::RTLDependenceKindTy::DepInOut;
      break;
    case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
      type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
      break;
    case mlir::omp::ClauseTaskDepend::taskdependinoutset:
      type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
      break;
    };
    llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
    llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
    dds.emplace_back(dd);
  }
}

/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  if (failed(checkImplementationStatus(*taskOp)))
    return failure();

  // Collect delayed privatisation declarations
  MutableArrayRef<BlockArgument> privateBlockArgs =
      cast<omp::BlockArgOpenMPOpInterface>(*taskOp).getPrivateBlockArgs();
  SmallVector<mlir::Value> mlirPrivateVars;
  SmallVector<llvm::Value *> llvmPrivateVars;
  SmallVector<omp::PrivateClauseOp> privateDecls;
  mlirPrivateVars.reserve(privateBlockArgs.size());
  llvmPrivateVars.reserve(privateBlockArgs.size());
  collectPrivatizationDecls(taskOp, privateDecls);
  for (mlir::Value privateVar : taskOp.getPrivateVars())
    mlirPrivateVars.push_back(privateVar);

  // Body callback: allocate and initialize private copies, emit the task
  // region, then run `dealloc` cleanup regions before the continuation point.
  auto bodyCB = [&](InsertPointTy allocaIP,
                    InsertPointTy codegenIP) -> llvm::Error {
    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
        moduleTranslation, allocaIP);

    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
        builder, moduleTranslation, privateBlockArgs, privateDecls,
        mlirPrivateVars, llvmPrivateVars, allocaIP);
    if (handleError(afterAllocas, *taskOp).failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(initFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
                                    llvmPrivateVars, privateDecls,
                                    afterAllocas.get())))
      return llvm::make_error<PreviouslyReportedError>();

    // translate the body of the task:
    builder.restoreIP(codegenIP);
    auto continuationBlockOrError = convertOmpOpRegions(
        taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
    if (failed(handleError(continuationBlockOrError, *taskOp)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());

    if (failed(cleanupPrivateVars(builder, moduleTranslation, taskOp.getLoc(),
                                  llvmPrivateVars, privateDecls)))
      return llvm::make_error<PreviouslyReportedError>();

    return llvm::Error::success();
  };

  SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
  buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
                  moduleTranslation, dds);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTask(
          ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
          moduleTranslation.lookupValue(taskOp.getFinal()),
          moduleTranslation.lookupValue(taskOp.getIfExpr()), dds,
          taskOp.getMergeable(),
          moduleTranslation.lookupValue(taskOp.getEventHandle()),
          moduleTranslation.lookupValue(taskOp.getPriority()));

  if (failed(handleError(afterIP, *taskOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}

/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  if (failed(checkImplementationStatus(*tgOp)))
    return failure();

  // Body callback: emit the taskgroup region contents at the codegen IP.
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
    builder.restoreIP(codegenIP);
    return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
                               builder, moduleTranslation)
        .takeError();
  };

  InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTaskgroup(ompLoc, allocaIP,
                                                            bodyCB);

  if (failed(handleError(afterIP, *tgOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}

/// Converts an OpenMP taskwait construct into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  if (failed(checkImplementationStatus(*twOp)))
    return failure();

  // taskwait has no region: just emit the runtime call at the current IP.
  moduleTranslation.getOpenMPBuilder()->createTaskwait(builder.saveIP());
  return success();
}

/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  auto wsloopOp = cast<omp::WsloopOp>(opInst);
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
  llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionByref());
  assert(isByRef.size() == wsloopOp.getNumReductionVars());

  // Static is the default.
  auto schedule =
      wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);

  // Find the loop configuration.
  llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
  llvm::Type *ivType = step->getType();
  llvm::Value *chunk = nullptr;
  if (wsloopOp.getScheduleChunk()) {
    // Coerce the chunk size to the induction variable type.
    llvm::Value *chunkVar =
        moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
    chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
  }

  // Collect delayed privatization declarations and their MLIR values.
  MutableArrayRef<BlockArgument> privateBlockArgs =
      cast<omp::BlockArgOpenMPOpInterface>(*wsloopOp).getPrivateBlockArgs();
  SmallVector<mlir::Value> mlirPrivateVars;
  SmallVector<llvm::Value *> llvmPrivateVars;
  SmallVector<omp::PrivateClauseOp> privateDecls;
  mlirPrivateVars.reserve(privateBlockArgs.size());
  llvmPrivateVars.reserve(privateBlockArgs.size());
  collectPrivatizationDecls(wsloopOp, privateDecls);

  for (mlir::Value privateVar : wsloopOp.getPrivateVars())
    mlirPrivateVars.push_back(privateVar);

  SmallVector<omp::DeclareReductionOp> reductionDecls;
  collectReductionDecls(wsloopOp, reductionDecls);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  SmallVector<llvm::Value *> privateReductionVariables(
      wsloopOp.getNumReductionVars());

  llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
      builder, moduleTranslation, privateBlockArgs, privateDecls,
      mlirPrivateVars, llvmPrivateVars, allocaIP);
  if (handleError(afterAllocas, opInst).failed())
    return failure();

  DenseMap<Value, llvm::Value *> reductionVariableMap;

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();

  SmallVector<DeferredStore> deferredStores;

  if (failed(allocReductionVars(wsloopOp, reductionArgs, builder,
                                moduleTranslation, allocaIP, reductionDecls,
                                privateReductionVariables, reductionVariableMap,
                                deferredStores, isByRef)))
    return failure();

  if (failed(initFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
                                  llvmPrivateVars, privateDecls,
                                  afterAllocas.get())))
    return failure();

  assert(afterAllocas.get()->getSinglePredecessor());
  if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
                               moduleTranslation,
                               afterAllocas.get()->getSinglePredecessor(),
                               reductionDecls, privateReductionVariables,
                               reductionVariableMap, isByRef, deferredStores)))
    return failure();

  // TODO: Replace this with proper composite translation support.
  // Currently, all nested wrappers are ignored, so 'do/for simd' will be
  // treated the same as a standalone 'do/for'. This is allowed by the spec,
  // since it's equivalent to always using a SIMD length of 1.
  if (failed(convertIgnoredWrappers(loopOp, wsloopOp, moduleTranslation)))
    return failure();

  // Store the mapping between reduction variables and their private copies on
  // ModuleTranslation stack. It can be then recovered when translating
  // omp.reduce operations in a separate call.
  LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
      moduleTranslation, reductionVariableMap);

  // Set up the source location value for OpenMP runtime.
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Generator of the canonical loop body.
  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
  SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
                     llvm::Value *iv) -> llvm::Error {
    // Make sure further conversions know about the induction variable.
    moduleTranslation.mapValue(
        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);

    // Capture the body insertion point for use in nested loops. BodyIP of the
    // CanonicalLoopInfo always points to the beginning of the entry block of
    // the body.
    bodyInsertPoints.push_back(ip);

    // Only the innermost loop emits the MLIR region; outer levels just record
    // their insertion point.
    if (loopInfos.size() != loopOp.getNumLoops() - 1)
      return llvm::Error::success();

    // Convert the body of the loop.
    builder.restoreIP(ip);
    return convertOmpOpRegions(loopOp.getRegion(), "omp.wsloop.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // Delegate actual loop construction to the OpenMP IRBuilder.
  // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
  // loop, i.e. it has a positive step, uses signed integer semantics.
  // Reconsider this code when the nested loop operation clearly supports more
  // cases.
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
    llvm::Value *lowerBound =
        moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
    llvm::Value *upperBound =
        moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);

    // Make sure loop trip count are emitted in the preheader of the outermost
    // loop at the latest so that they are all available for the new collapsed
    // loop will be created below.
    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
    if (i != 0) {
      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back());
      computeIP = loopInfos.front()->getPreheaderIP();
    }

    llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
        ompBuilder->createCanonicalLoop(
            loc, bodyGen, lowerBound, upperBound, step,
            /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);

    if (failed(handleError(loopResult, *loopOp)))
      return failure();

    loopInfos.push_back(*loopResult);
  }

  // Collapse loops. Store the insertion point because LoopInfos may get
  // invalidated.
  llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
  llvm::CanonicalLoopInfo *loopInfo =
      ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});

  allocaIP = findAllocaInsertPoint(builder, moduleTranslation);

  // TODO: Handle doacross loops when the ordered clause has a parameter.
  bool isOrdered = wsloopOp.getOrdered().has_value();
  std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
  bool isSimd = wsloopOp.getScheduleSimd();

  // Apply the worksharing-loop directive to the collapsed loop.
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
      ompBuilder->applyWorkshareLoop(
          ompLoc.DL, loopInfo, allocaIP, !wsloopOp.getNowait(),
          convertToScheduleKind(schedule), chunk, isSimd,
          scheduleMod == omp::ScheduleModifier::monotonic,
          scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered);

  if (failed(handleError(wsloopIP, opInst)))
    return failure();

  // Continue building IR after the loop. Note that the LoopInfo returned by
  // `collapseLoops` points inside the outermost loop and is intended for
  // potential further loop transformations. Use the insertion point stored
  // before collapsing loops instead.
  builder.restoreIP(afterIP);

  // Process the reductions if required.
  if (failed(createReductionsAndCleanup(wsloopOp, builder, moduleTranslation,
                                        allocaIP, reductionDecls,
                                        privateReductionVariables, isByRef)))
    return failure();

  return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
                            llvmPrivateVars, privateDecls);
}

/// Converts the OpenMP parallel operation to LLVM IR.
static LogicalResult
convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref());
  assert(isByRef.size() == opInst.getNumReductionVars());
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  if (failed(checkImplementationStatus(*opInst)))
    return failure();

  // Collect delayed privatization declarations
  MutableArrayRef<BlockArgument> privateBlockArgs =
      cast<omp::BlockArgOpenMPOpInterface>(*opInst).getPrivateBlockArgs();
  SmallVector<mlir::Value> mlirPrivateVars;
  SmallVector<llvm::Value *> llvmPrivateVars;
  SmallVector<omp::PrivateClauseOp> privateDecls;
  mlirPrivateVars.reserve(privateBlockArgs.size());
  llvmPrivateVars.reserve(privateBlockArgs.size());
  collectPrivatizationDecls(opInst, privateDecls);
  for (mlir::Value privateVar : opInst.getPrivateVars())
    mlirPrivateVars.push_back(privateVar);

  // Collect reduction declarations
  SmallVector<omp::DeclareReductionOp> reductionDecls;
  collectReductionDecls(opInst, reductionDecls);
  SmallVector<llvm::Value *> privateReductionVariables(
      opInst.getNumReductionVars());
  SmallVector<DeferredStore> deferredStores;

  // Body callback: allocate/init private and reduction variables, emit the
  // parallel region, and generate the reduction combiners at its end.
  auto bodyGenCB = [&](InsertPointTy allocaIP,
                       InsertPointTy codeGenIP) -> llvm::Error {
    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
        builder, moduleTranslation, privateBlockArgs, privateDecls,
        mlirPrivateVars, llvmPrivateVars, allocaIP);
    if (handleError(afterAllocas, *opInst).failed())
      return llvm::make_error<PreviouslyReportedError>();

    // Allocate reduction vars
    DenseMap<Value, llvm::Value *> reductionVariableMap;

    MutableArrayRef<BlockArgument> reductionArgs =
        cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();

    // Re-anchor the alloca IP at the block terminator so later allocas land
    // after the ones emitted above.
    allocaIP =
        InsertPointTy(allocaIP.getBlock(),
                      allocaIP.getBlock()->getTerminator()->getIterator());

    if (failed(allocReductionVars(
            opInst, reductionArgs, builder, moduleTranslation, allocaIP,
            reductionDecls, privateReductionVariables, reductionVariableMap,
            deferredStores, isByRef)))
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(initFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
                                    llvmPrivateVars, privateDecls,
                                    afterAllocas.get())))
      return llvm::make_error<PreviouslyReportedError>();

    assert(afterAllocas.get()->getSinglePredecessor());
    builder.restoreIP(codeGenIP);

    if (failed(
            initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
                              afterAllocas.get()->getSinglePredecessor(),
                              reductionDecls, privateReductionVariables,
                              reductionVariableMap, isByRef, deferredStores)))
      return llvm::make_error<PreviouslyReportedError>();

    // Store the mapping between reduction variables and their private copies on
    // ModuleTranslation stack. It can be then recovered when translating
    // omp.reduce operations in a separate call.
    LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
        moduleTranslation, reductionVariableMap);

    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
        moduleTranslation, allocaIP);

    // ParallelOp has only one region associated with it.
    llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
        opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();

    // Process the reductions if required.
    if (opInst.getNumReductionVars() > 0) {
      // Collect reduction info
      SmallVector<OwningReductionGen> owningReductionGens;
      SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
      SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
      collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
                           owningReductionGens, owningAtomicReductionGens,
                           privateReductionVariables, reductionInfos);

      // Move to region cont block
      builder.SetInsertPoint((*regionBlock)->getTerminator());

      // Generate reductions from info
      // A temporary terminator marks the split point; it is erased once
      // createReductions has produced the continuation insert point.
      llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
      builder.SetInsertPoint(tempTerminator);

      llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
          ompBuilder->createReductions(builder.saveIP(), allocaIP,
                                       reductionInfos, isByRef, false);
      if (!contInsertPoint)
        return contInsertPoint.takeError();

      if (!contInsertPoint->getBlock())
        return llvm::make_error<PreviouslyReportedError>();

      tempTerminator->eraseFromParent();
      builder.restoreIP(*contInsertPoint);
    }
    return llvm::Error::success();
  };

  auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
                   llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
    // tell OpenMPIRBuilder not to do anything. We handled Privatisation in
    // bodyGenCB.
    replVal = &val;
    return codeGenIP;
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
    InsertPointTy oldIP = builder.saveIP();
    builder.restoreIP(codeGenIP);

    // if the reduction has a cleanup region, inline it here to finalize the
    // reduction variables
    SmallVector<Region *> reductionCleanupRegions;
    llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
                    [](omp::DeclareReductionOp reductionDecl) {
                      return &reductionDecl.getCleanupRegion();
                    });
    if (failed(inlineOmpRegionCleanup(
            reductionCleanupRegions, privateReductionVariables,
            moduleTranslation, builder, "omp.reduction.cleanup")))
      return llvm::createStringError(
          "failed to inline `cleanup` region of `omp.declare_reduction`");

    if (failed(cleanupPrivateVars(builder, moduleTranslation, opInst.getLoc(),
                                  llvmPrivateVars, privateDecls)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.restoreIP(oldIP);
    return llvm::Error::success();
  };

  // Translate the optional clause operands (null when absent).
  llvm::Value *ifCond = nullptr;
  if (auto ifVar = opInst.getIfExpr())
    ifCond = moduleTranslation.lookupValue(ifVar);
  llvm::Value *numThreads = nullptr;
  if (auto numThreadsVar = opInst.getNumThreads())
    numThreads = moduleTranslation.lookupValue(numThreadsVar);
  auto pbKind = llvm::omp::OMP_PROC_BIND_default;
  if (auto bind = opInst.getProcBindKind())
    pbKind = getProcBindKind(*bind);
  // TODO: Is the Parallel construct cancellable?
  bool isCancellable = false;

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
                                 ifCond, numThreads, pbKind, isCancellable);

  if (failed(handleError(afterIP, *opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}

/// Convert Order attribute to llvm::omp::OrderKind.
static llvm::omp::OrderKind
convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
  if (!o)
    return llvm::omp::OrderKind::OMP_ORDER_unknown;
  switch (*o) {
  case omp::ClauseOrderKind::Concurrent:
    return llvm::omp::OrderKind::OMP_ORDER_concurrent;
  }
  llvm_unreachable("Unknown ClauseOrderKind kind");
}

/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
               LLVM::ModuleTranslation &moduleTranslation) {
  auto simdOp = cast<omp::SimdOp>(opInst);
  auto loopOp = cast<omp::LoopNestOp>(simdOp.getWrappedLoop());

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // Collect delayed privatization declarations and their MLIR values.
  MutableArrayRef<BlockArgument> privateBlockArgs =
      cast<omp::BlockArgOpenMPOpInterface>(*simdOp).getPrivateBlockArgs();
  SmallVector<mlir::Value> mlirPrivateVars;
  SmallVector<llvm::Value *> llvmPrivateVars;
  SmallVector<omp::PrivateClauseOp> privateDecls;
  mlirPrivateVars.reserve(privateBlockArgs.size());
  llvmPrivateVars.reserve(privateBlockArgs.size());
  collectPrivatizationDecls(simdOp, privateDecls);

  for (mlir::Value privateVar : simdOp.getPrivateVars())
    mlirPrivateVars.push_back(privateVar);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
      builder, moduleTranslation, privateBlockArgs, privateDecls,
      mlirPrivateVars, llvmPrivateVars, allocaIP);
  if (handleError(afterAllocas, opInst).failed())
    return failure();

  // Generator of the canonical loop body.
  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
  SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
                     llvm::Value *iv) -> llvm::Error {
    // Make sure further conversions know about the induction variable.
    moduleTranslation.mapValue(
        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);

    // Capture the body insertion point for use in nested loops. BodyIP of the
    // CanonicalLoopInfo always points to the beginning of the entry block of
    // the body.
    bodyInsertPoints.push_back(ip);

    // Only the innermost loop emits the MLIR region; outer levels just record
    // their insertion point.
    if (loopInfos.size() != loopOp.getNumLoops() - 1)
      return llvm::Error::success();

    // Convert the body of the loop.
    builder.restoreIP(ip);
    return convertOmpOpRegions(loopOp.getRegion(), "omp.simd.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // Delegate actual loop construction to the OpenMP IRBuilder.
  // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
  // loop, i.e. it has a positive step, uses signed integer semantics.
  // Reconsider this code when the nested loop operation clearly supports more
  // cases.
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
    llvm::Value *lowerBound =
        moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
    llvm::Value *upperBound =
        moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);

    // Make sure loop trip count are emitted in the preheader of the outermost
    // loop at the latest so that they are all available for the new collapsed
    // loop will be created below.
    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
    if (i != 0) {
      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
                                                       ompLoc.DL);
      computeIP = loopInfos.front()->getPreheaderIP();
    }

    llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
        ompBuilder->createCanonicalLoop(
            loc, bodyGen, lowerBound, upperBound, step,
            /*IsSigned=*/true, /*InclusiveStop=*/true, computeIP);

    if (failed(handleError(loopResult, *loopOp)))
      return failure();

    loopInfos.push_back(*loopResult);
  }

  // Collapse loops.
  llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
  llvm::CanonicalLoopInfo *loopInfo =
      ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});

  // Translate optional simdlen/safelen clause values.
  llvm::ConstantInt *simdlen = nullptr;
  if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
    simdlen = builder.getInt64(simdlenVar.value());

  llvm::ConstantInt *safelen = nullptr;
  if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
    safelen = builder.getInt64(safelenVar.value());

  llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
  llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
  std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
  mlir::OperandRange operands = simdOp.getAlignedVars();
  // For each aligned variable with an integer alignment attribute, load the
  // pointer in the source block and record its alignment for applySimd.
  // NOTE(review): `(*alignmentValues)[i]` assumes the alignments attribute is
  // present and parallel to the aligned vars whenever any exist — presumably
  // an op verifier invariant; confirm.
  for (size_t i = 0; i < operands.size(); ++i) {
    llvm::Value *alignment = nullptr;
    llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
    llvm::Type *ty = llvmVal->getType();
    if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
      alignment = builder.getInt64(intAttr.getInt());
      assert(ty->isPointerTy() && "Invalid type for aligned variable");
      assert(alignment && "Invalid alignment value");
      auto curInsert = builder.saveIP();
      builder.SetInsertPoint(sourceBlock->getTerminator());
      llvmVal = builder.CreateLoad(ty, llvmVal);
      builder.restoreIP(curInsert);
      alignedVars[llvmVal] = alignment;
    }
  }
  ompBuilder->applySimd(loopInfo, alignedVars,
                        simdOp.getIfExpr()
                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
                            : nullptr,
                        order, simdlen, safelen);

  builder.restoreIP(afterIP);

  return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
                            llvmPrivateVars, privateDecls);
}

/// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
2379 static llvm::AtomicOrdering 2380 convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) { 2381 if (!ao) 2382 return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering 2383 2384 switch (*ao) { 2385 case omp::ClauseMemoryOrderKind::Seq_cst: 2386 return llvm::AtomicOrdering::SequentiallyConsistent; 2387 case omp::ClauseMemoryOrderKind::Acq_rel: 2388 return llvm::AtomicOrdering::AcquireRelease; 2389 case omp::ClauseMemoryOrderKind::Acquire: 2390 return llvm::AtomicOrdering::Acquire; 2391 case omp::ClauseMemoryOrderKind::Release: 2392 return llvm::AtomicOrdering::Release; 2393 case omp::ClauseMemoryOrderKind::Relaxed: 2394 return llvm::AtomicOrdering::Monotonic; 2395 } 2396 llvm_unreachable("Unknown ClauseMemoryOrderKind kind"); 2397 } 2398 2399 /// Convert omp.atomic.read operation to LLVM IR. 2400 static LogicalResult 2401 convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, 2402 LLVM::ModuleTranslation &moduleTranslation) { 2403 auto readOp = cast<omp::AtomicReadOp>(opInst); 2404 if (failed(checkImplementationStatus(opInst))) 2405 return failure(); 2406 2407 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); 2408 2409 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); 2410 2411 llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.getMemoryOrder()); 2412 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX()); 2413 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV()); 2414 2415 llvm::Type *elementType = 2416 moduleTranslation.convertType(readOp.getElementType()); 2417 2418 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false}; 2419 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false}; 2420 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO)); 2421 return success(); 2422 } 2423 2424 /// Converts an omp.atomic.write operation to LLVM IR. 
2425 static LogicalResult 2426 convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, 2427 LLVM::ModuleTranslation &moduleTranslation) { 2428 auto writeOp = cast<omp::AtomicWriteOp>(opInst); 2429 if (failed(checkImplementationStatus(opInst))) 2430 return failure(); 2431 2432 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); 2433 2434 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); 2435 llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrder()); 2436 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr()); 2437 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX()); 2438 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType()); 2439 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false, 2440 /*isVolatile=*/false}; 2441 builder.restoreIP(ompBuilder->createAtomicWrite(ompLoc, x, expr, ao)); 2442 return success(); 2443 } 2444 2445 /// Converts an LLVM dialect binary operation to the corresponding enum value 2446 /// for `atomicrmw` supported binary operation. 
2447 llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) { 2448 return llvm::TypeSwitch<Operation *, llvm::AtomicRMWInst::BinOp>(&op) 2449 .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; }) 2450 .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; }) 2451 .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; }) 2452 .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; }) 2453 .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; }) 2454 .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; }) 2455 .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; }) 2456 .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; }) 2457 .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; }) 2458 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP); 2459 } 2460 2461 /// Converts an OpenMP atomic update operation using OpenMPIRBuilder. 2462 static LogicalResult 2463 convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, 2464 llvm::IRBuilderBase &builder, 2465 LLVM::ModuleTranslation &moduleTranslation) { 2466 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); 2467 if (failed(checkImplementationStatus(*opInst))) 2468 return failure(); 2469 2470 // Convert values and types. 2471 auto &innerOpList = opInst.getRegion().front().getOperations(); 2472 bool isXBinopExpr{false}; 2473 llvm::AtomicRMWInst::BinOp binop; 2474 mlir::Value mlirExpr; 2475 llvm::Value *llvmExpr = nullptr; 2476 llvm::Value *llvmX = nullptr; 2477 llvm::Type *llvmXElementType = nullptr; 2478 if (innerOpList.size() == 2) { 2479 // The two operations here are the update and the terminator. 2480 // Since we can identify the update operation, there is a possibility 2481 // that we can generate the atomicrmw instruction. 
2482 mlir::Operation &innerOp = *opInst.getRegion().front().begin(); 2483 if (!llvm::is_contained(innerOp.getOperands(), 2484 opInst.getRegion().getArgument(0))) { 2485 return opInst.emitError("no atomic update operation with region argument" 2486 " as operand found inside atomic.update region"); 2487 } 2488 binop = convertBinOpToAtomic(innerOp); 2489 isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0); 2490 mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0)); 2491 llvmExpr = moduleTranslation.lookupValue(mlirExpr); 2492 } else { 2493 // Since the update region includes more than one operation 2494 // we will resort to generating a cmpxchg loop. 2495 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP; 2496 } 2497 llvmX = moduleTranslation.lookupValue(opInst.getX()); 2498 llvmXElementType = moduleTranslation.convertType( 2499 opInst.getRegion().getArgument(0).getType()); 2500 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType, 2501 /*isSigned=*/false, 2502 /*isVolatile=*/false}; 2503 2504 llvm::AtomicOrdering atomicOrdering = 2505 convertAtomicOrdering(opInst.getMemoryOrder()); 2506 2507 // Generate update code. 
2508 auto updateFn = 2509 [&opInst, &moduleTranslation]( 2510 llvm::Value *atomicx, 2511 llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> { 2512 Block &bb = *opInst.getRegion().begin(); 2513 moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx); 2514 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock()); 2515 if (failed(moduleTranslation.convertBlock(bb, true, builder))) 2516 return llvm::make_error<PreviouslyReportedError>(); 2517 2518 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator()); 2519 assert(yieldop && yieldop.getResults().size() == 1 && 2520 "terminator must be omp.yield op and it must have exactly one " 2521 "argument"); 2522 return moduleTranslation.lookupValue(yieldop.getResults()[0]); 2523 }; 2524 2525 // Handle ambiguous alloca, if any. 2526 auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation); 2527 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); 2528 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = 2529 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr, 2530 atomicOrdering, binop, updateFn, 2531 isXBinopExpr); 2532 2533 if (failed(handleError(afterIP, *opInst))) 2534 return failure(); 2535 2536 builder.restoreIP(*afterIP); 2537 return success(); 2538 } 2539 2540 static LogicalResult 2541 convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, 2542 llvm::IRBuilderBase &builder, 2543 LLVM::ModuleTranslation &moduleTranslation) { 2544 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); 2545 if (failed(checkImplementationStatus(*atomicCaptureOp))) 2546 return failure(); 2547 2548 mlir::Value mlirExpr; 2549 bool isXBinopExpr = false, isPostfixUpdate = false; 2550 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP; 2551 2552 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp(); 2553 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp(); 2554 2555 assert((atomicUpdateOp || 
atomicWriteOp) && 2556 "internal op must be an atomic.update or atomic.write op"); 2557 2558 if (atomicWriteOp) { 2559 isPostfixUpdate = true; 2560 mlirExpr = atomicWriteOp.getExpr(); 2561 } else { 2562 isPostfixUpdate = atomicCaptureOp.getSecondOp() == 2563 atomicCaptureOp.getAtomicUpdateOp().getOperation(); 2564 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations(); 2565 // Find the binary update operation that uses the region argument 2566 // and get the expression to update 2567 if (innerOpList.size() == 2) { 2568 mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin(); 2569 if (!llvm::is_contained(innerOp.getOperands(), 2570 atomicUpdateOp.getRegion().getArgument(0))) { 2571 return atomicUpdateOp.emitError( 2572 "no atomic update operation with region argument" 2573 " as operand found inside atomic.update region"); 2574 } 2575 binop = convertBinOpToAtomic(innerOp); 2576 isXBinopExpr = 2577 innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0); 2578 mlirExpr = (isXBinopExpr ? 
innerOp.getOperand(1) : innerOp.getOperand(0)); 2579 } else { 2580 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP; 2581 } 2582 } 2583 2584 llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr); 2585 llvm::Value *llvmX = 2586 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX()); 2587 llvm::Value *llvmV = 2588 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV()); 2589 llvm::Type *llvmXElementType = moduleTranslation.convertType( 2590 atomicCaptureOp.getAtomicReadOp().getElementType()); 2591 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType, 2592 /*isSigned=*/false, 2593 /*isVolatile=*/false}; 2594 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType, 2595 /*isSigned=*/false, 2596 /*isVolatile=*/false}; 2597 2598 llvm::AtomicOrdering atomicOrdering = 2599 convertAtomicOrdering(atomicCaptureOp.getMemoryOrder()); 2600 2601 auto updateFn = 2602 [&](llvm::Value *atomicx, 2603 llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> { 2604 if (atomicWriteOp) 2605 return moduleTranslation.lookupValue(atomicWriteOp.getExpr()); 2606 Block &bb = *atomicUpdateOp.getRegion().begin(); 2607 moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(), 2608 atomicx); 2609 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock()); 2610 if (failed(moduleTranslation.convertBlock(bb, true, builder))) 2611 return llvm::make_error<PreviouslyReportedError>(); 2612 2613 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator()); 2614 assert(yieldop && yieldop.getResults().size() == 1 && 2615 "terminator must be omp.yield op and it must have exactly one " 2616 "argument"); 2617 return moduleTranslation.lookupValue(yieldop.getResults()[0]); 2618 }; 2619 2620 // Handle ambiguous alloca, if any. 
2621 auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation); 2622 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); 2623 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = 2624 ompBuilder->createAtomicCapture( 2625 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering, 2626 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr); 2627 2628 if (failed(handleError(afterIP, *atomicCaptureOp))) 2629 return failure(); 2630 2631 builder.restoreIP(*afterIP); 2632 return success(); 2633 } 2634 2635 /// Converts an OpenMP Threadprivate operation into LLVM IR using 2636 /// OpenMPIRBuilder. 2637 static LogicalResult 2638 convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, 2639 LLVM::ModuleTranslation &moduleTranslation) { 2640 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); 2641 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); 2642 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst); 2643 2644 if (failed(checkImplementationStatus(opInst))) 2645 return failure(); 2646 2647 Value symAddr = threadprivateOp.getSymAddr(); 2648 auto *symOp = symAddr.getDefiningOp(); 2649 2650 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp)) 2651 symOp = asCast.getOperand().getDefiningOp(); 2652 2653 if (!isa<LLVM::AddressOfOp>(symOp)) 2654 return opInst.emitError("Addressing symbol not found"); 2655 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp); 2656 2657 LLVM::GlobalOp global = 2658 addressOfOp.getGlobal(moduleTranslation.symbolTable()); 2659 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global); 2660 2661 if (!ompBuilder->Config.isTargetDevice()) { 2662 llvm::Type *type = globalValue->getValueType(); 2663 llvm::TypeSize typeSize = 2664 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize( 2665 type); 2666 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue()); 2667 llvm::Value *callInst = 
ompBuilder->createCachedThreadPrivate( 2668 ompLoc, globalValue, size, global.getSymName() + ".cache"); 2669 moduleTranslation.mapValue(opInst.getResult(0), callInst); 2670 } else { 2671 moduleTranslation.mapValue(opInst.getResult(0), globalValue); 2672 } 2673 2674 return success(); 2675 } 2676 2677 static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind 2678 convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) { 2679 switch (deviceClause) { 2680 case mlir::omp::DeclareTargetDeviceType::host: 2681 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost; 2682 break; 2683 case mlir::omp::DeclareTargetDeviceType::nohost: 2684 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost; 2685 break; 2686 case mlir::omp::DeclareTargetDeviceType::any: 2687 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny; 2688 break; 2689 } 2690 llvm_unreachable("unhandled device clause"); 2691 } 2692 2693 static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind 2694 convertToCaptureClauseKind( 2695 mlir::omp::DeclareTargetCaptureClause captureClause) { 2696 switch (captureClause) { 2697 case mlir::omp::DeclareTargetCaptureClause::to: 2698 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo; 2699 case mlir::omp::DeclareTargetCaptureClause::link: 2700 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink; 2701 case mlir::omp::DeclareTargetCaptureClause::enter: 2702 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter; 2703 } 2704 llvm_unreachable("unhandled capture clause"); 2705 } 2706 2707 static llvm::SmallString<64> 2708 getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, 2709 llvm::OpenMPIRBuilder &ompBuilder) { 2710 llvm::SmallString<64> suffix; 2711 llvm::raw_svector_ostream os(suffix); 2712 if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) { 2713 auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>(); 2714 auto fileInfoCallBack = 
[&loc]() { 2715 return std::pair<std::string, uint64_t>( 2716 llvm::StringRef(loc.getFilename()), loc.getLine()); 2717 }; 2718 2719 os << llvm::format( 2720 "_%x", ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack).FileID); 2721 } 2722 os << "_decl_tgt_ref_ptr"; 2723 2724 return suffix; 2725 } 2726 2727 static bool isDeclareTargetLink(mlir::Value value) { 2728 if (auto addressOfOp = 2729 llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) { 2730 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>(); 2731 Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName()); 2732 if (auto declareTargetGlobal = 2733 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp)) 2734 if (declareTargetGlobal.getDeclareTargetCaptureClause() == 2735 mlir::omp::DeclareTargetCaptureClause::link) 2736 return true; 2737 } 2738 return false; 2739 } 2740 2741 // Returns the reference pointer generated by the lowering of the declare target 2742 // operation in cases where the link clause is used or the to clause is used in 2743 // USM mode. 
static llvm::Value *
getRefPtrIfDeclareTarget(mlir::Value value,
                         LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  // An easier way to do this may just be to keep track of any pointer
  // references and their mapping to their respective operation
  if (auto addressOfOp =
          llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
    // Resolve the referenced symbol back to its global definition.
    if (auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>(
            addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol(
                addressOfOp.getGlobalName()))) {

      if (auto declareTargetGlobal =
              llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
                  gOp.getOperation())) {

        // In this case, we must utilise the reference pointer generated by the
        // declare target operation, similar to Clang
        if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
             mlir::omp::DeclareTargetCaptureClause::link) ||
            (declareTargetGlobal.getDeclareTargetCaptureClause() ==
                 mlir::omp::DeclareTargetCaptureClause::to &&
             ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
          llvm::SmallString<64> suffix =
              getDeclareTargetRefPtrSuffix(gOp, *ompBuilder);

          // The symbol may already be the reference pointer itself (its name
          // contains the suffix); otherwise look up the suffixed variant.
          if (gOp.getSymName().contains(suffix))
            return moduleTranslation.getLLVMModule()->getNamedValue(
                gOp.getSymName());

          return moduleTranslation.getLLVMModule()->getNamedValue(
              (gOp.getSymName().str() + suffix.str()).str());
        }
      }
    }
  }

  // Not a declare-target global needing an indirection.
  return nullptr;
}

namespace {
// A small helper structure to contain data gathered
// for map lowering and coalesce it into one area,
// avoiding extra computations such as searches in the
// llvm module for lowered mapped variables or checking
// if something is declare target (and retrieving the
// value) more than necessary.
struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
  // Whether each entry refers to a declare-target global (parallel to the
  // inherited MapInfosTy vectors).
  llvm::SmallVector<bool, 4> IsDeclareTarget;
  // Whether each entry is a member of a larger mapped object.
  llvm::SmallVector<bool, 4> IsAMember;
  // Identify if mapping was added by mapClause or use_device clauses.
  llvm::SmallVector<bool, 4> IsAMapping;
  // The originating omp.map.info operation for each entry.
  llvm::SmallVector<mlir::Operation *, 4> MapClause;
  // The looked-up LLVM value before any declare-target indirection.
  llvm::SmallVector<llvm::Value *, 4> OriginalValue;
  // Stripped off array/pointer to get the underlying
  // element type
  llvm::SmallVector<llvm::Type *, 4> BaseType;

  /// Append arrays in \a CurInfo.
  void append(MapInfoData &CurInfo) {
    IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
                           CurInfo.IsDeclareTarget.end());
    MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
    OriginalValue.append(CurInfo.OriginalValue.begin(),
                         CurInfo.OriginalValue.end());
    BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
    // Also append the vectors owned by the base MapInfosTy.
    llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo);
  }
};
} // namespace

// Recursively strips nested array types and returns the size in bits of the
// innermost element type according to the data layout.
uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl) {
  if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
          arrTy.getElementType()))
    return getArrayElementSizeInBits(nestedArrTy, dl);
  return dl.getTypeSizeInBits(arrTy.getElementType());
}

// This function calculates the size to be offloaded for a specified type, given
// its associated map clause (which can contain bounds information which affects
// the total size), this size is calculated based on the underlying element type
// e.g. given a 1-D array of ints, we will calculate the size from the integer
// type * number of elements in the array. This size can be used in other
// calculations but is ultimately used as an argument to the OpenMP runtimes
// kernel argument structure which is generated through the combinedInfo data
// structures.
// This function is somewhat equivalent to Clang's getExprTypeSize inside of
// CGOpenMPRuntime.cpp.
llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
                            Operation *clauseOp, llvm::Value *basePointer,
                            llvm::Type *baseType, llvm::IRBuilderBase &builder,
                            LLVM::ModuleTranslation &moduleTranslation) {
  if (auto memberClause =
          mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
    // This calculates the size to transfer based on bounds and the underlying
    // element type, provided bounds have been specified (Fortran
    // pointers/allocatables/target and arrays that have sections specified fall
    // into this as well).
    if (!memberClause.getBounds().empty()) {
      llvm::Value *elementCount = builder.getInt64(1);
      for (auto bounds : memberClause.getBounds()) {
        if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
                bounds.getDefiningOp())) {
          // The below calculation for the size to be mapped calculated from the
          // map.info's bounds is: (elemCount * [UB - LB] + 1), later we
          // multiply by the underlying element types byte size to get the full
          // size to be offloaded based on the bounds
          elementCount = builder.CreateMul(
              elementCount,
              builder.CreateAdd(
                  builder.CreateSub(
                      moduleTranslation.lookupValue(boundOp.getUpperBound()),
                      moduleTranslation.lookupValue(boundOp.getLowerBound())),
                  builder.getInt64(1)));
        }
      }

      // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
      // the size in inconsistent byte or bit format.
      uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
      if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
        underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);

      // The size in bytes x number of elements, the sizeInBytes stored is
      // the underlying types size, e.g. if ptr<i32>, it'll be the i32's
      // size, so we do some on the fly runtime math to get the size in
      // bytes from the extent (ub - lb) * sizeInBytes. NOTE: This may need
      // some adjustment for members with more complex types.
      return builder.CreateMul(elementCount,
                               builder.getInt64(underlyingTypeSzInBits / 8));
    }
  }

  // No usable bounds: fall back to the static size of the type itself.
  return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
}

// Gathers, for every map/use_device operand, the parallel-array entries in
// \p mapData (pointers, base pointers, sizes, types, names, ...) consumed by
// the later map lowering. Entries from map clauses are collected first, then
// use_device_addr/use_device_ptr operands that are not already covered.
static void collectMapDataFromMapOperands(
    MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
    LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
    llvm::IRBuilderBase &builder, const ArrayRef<Value> &useDevPtrOperands = {},
    const ArrayRef<Value> &useDevAddrOperands = {}) {
  auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
    // Check if this is a member mapping and correctly assign that it is, if
    // it is a member of a larger object.
    // TODO: Need better handling of members, and distinguishing of members
    // that are implicitly allocated on device vs explicitly passed in as
    // arguments.
    // TODO: May require some further additions to support nested record
    // types, i.e. member maps that can have member maps.
    for (Value mapValue : mapVars) {
      auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
      for (auto member : map.getMembers())
        if (member == mapOp)
          return true;
    }
    return false;
  };

  // Process MapOperands
  for (Value mapValue : mapVars) {
    auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
    // Prefer the pointer-to-pointer when present (e.g. descriptor base
    // addresses); otherwise map the variable's own address.
    Value offloadPtr =
        mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
    mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
    mapData.Pointers.push_back(mapData.OriginalValue.back());

    if (llvm::Value *refPtr =
            getRefPtrIfDeclareTarget(offloadPtr,
                                     moduleTranslation)) { // declare target
      mapData.IsDeclareTarget.push_back(true);
      mapData.BasePointers.push_back(refPtr);
    } else { // regular mapped variable
      mapData.IsDeclareTarget.push_back(false);
      mapData.BasePointers.push_back(mapData.OriginalValue.back());
    }

    mapData.BaseType.push_back(
        moduleTranslation.convertType(mapOp.getVarType()));
    mapData.Sizes.push_back(
        getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
                       mapData.BaseType.back(), builder, moduleTranslation));
    mapData.MapClause.push_back(mapOp.getOperation());
    mapData.Types.push_back(
        llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
    mapData.Names.push_back(LLVM::createMappingInformation(
        mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
    mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
    mapData.IsAMapping.push_back(true);
    mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
  }

  // Marks an already-collected map entry for \p val as a return parameter with
  // the given device-info kind; returns whether such an entry existed.
  auto findMapInfo = [&mapData](llvm::Value *val,
                                llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
    unsigned index = 0;
    bool found = false;
    for (llvm::Value *basePtr : mapData.OriginalValue) {
      if (basePtr == val && mapData.IsAMapping[index]) {
        found = true;
        mapData.Types[index] |=
            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
        mapData.DevicePointers[index] = devInfoTy;
      }
      index++;
    }
    return found;
  };

  // Process useDevPtr(Addr)Operands
  auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
                         llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
    for (Value mapValue : useDevOperands) {
      auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
      Value offloadPtr =
          mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
      llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);

      // Check if map info is already present for this entry.
      if (!findMapInfo(origValue, devInfoTy)) {
        mapData.OriginalValue.push_back(origValue);
        mapData.Pointers.push_back(mapData.OriginalValue.back());
        mapData.IsDeclareTarget.push_back(false);
        mapData.BasePointers.push_back(mapData.OriginalValue.back());
        mapData.BaseType.push_back(
            moduleTranslation.convertType(mapOp.getVarType()));
        mapData.Sizes.push_back(builder.getInt64(0));
        mapData.MapClause.push_back(mapOp.getOperation());
        mapData.Types.push_back(
            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
        mapData.Names.push_back(LLVM::createMappingInformation(
            mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
        mapData.DevicePointers.push_back(devInfoTy);
        mapData.IsAMapping.push_back(false);
        mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
      }
    }
  };

  addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
  addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
}

// Returns the index of \p memberOp's entry in \p mapData; the entry must have
// been collected previously.
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
  auto *res = llvm::find(mapData.MapClause, memberOp);
  assert(res != mapData.MapClause.end() &&
         "MapInfoOp for member not found in MapData, cannot return index");
  return std::distance(mapData.MapClause.begin(), res);
}

// Returns the lexicographically first (\p first == true) or last mapped member
// of \p mapInfo according to the members-index attribute.
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
                                                    bool first) {
  ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
  // Only 1 member has been mapped, we can return it.
  if (indexAttr.size() == 1)
    return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());

  llvm::SmallVector<size_t> indices(indexAttr.size());
  std::iota(indices.begin(), indices.end(), 0);

  // Sort member positions lexicographically by their index paths; flipping
  // the comparison via `first` lets one comparator find either extreme.
  llvm::sort(indices.begin(), indices.end(),
             [&](const size_t a, const size_t b) {
               auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
               auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
               for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
                 int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
                 int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();

                 if (aIndex == bIndex)
                   continue;

                 if (aIndex < bIndex)
                   return first;

                 if (aIndex > bIndex)
                   return !first;
               }

               // Iterated up until the end of the smallest member and
               // they were found to be equal up to that point, so select
               // the member with the lowest index count, so the "parent"
               return memberIndicesA.size() < memberIndicesB.size();
             });

  return llvm::cast<omp::MapInfoOp>(
      mapInfo.getMembers()[indices.front()].getDefiningOp());
}

/// This function calculates the array/pointer offset for map data provided
/// with bounds operations, e.g. when provided something like the following:
///
/// Fortran
///     map(tofrom: array(2:5, 3:2))
///   or
/// C++
///     map(tofrom: array[1:4][2:3])
/// We must calculate the initial pointer offset to pass across, this function
/// performs this using bounds.
///
/// NOTE: while specified in row-major order it currently needs to be
/// flipped for Fortran's column order array allocation and access (as
/// opposed to C++'s row-major, hence the backwards processing where order is
/// important). This is likely important to keep in mind for the future when
/// we incorporate a C++ frontend, both frontends will need to agree on the
/// ordering of generated bounds operations (one may have to flip them) to
/// make the below lowering frontend agnostic. The offload size
/// calculation may also have to be adjusted for C++.
std::vector<llvm::Value *>
calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
                      llvm::IRBuilderBase &builder, bool isArrayTy,
                      OperandRange bounds) {
  std::vector<llvm::Value *> idx;
  // There's no bounds to calculate an offset from, we can safely
  // ignore and return no indices.
  if (bounds.empty())
    return idx;

  // If we have an array type, then we have its type so can treat it as a
  // normal GEP instruction where the bounds operations are simply indexes
  // into the array. We currently do reverse order of the bounds, which
  // I believe leans more towards Fortran's column-major in memory.
  if (isArrayTy) {
    idx.push_back(builder.getInt64(0));
    for (int i = bounds.size() - 1; i >= 0; --i) {
      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
              bounds[i].getDefiningOp())) {
        idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
      }
    }
  } else {
    // If we do not have an array type, but we have bounds, then we're dealing
    // with a pointer that's being treated like an array and we have the
    // underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
    // address (pointer pointing to the actual data) so we must calculate the
    // offset using a single index which the following two loops attempts to
    // compute.

    // Calculates the size offset we need to make per row e.g. first row or
    // column only needs to be offset by one, but the next would have to be
    // the previous row/column offset multiplied by the extent of current row.
    //
    // For example ([1][10][100]):
    //
    // - First row/column we move by 1 for each index increment
    // - Second row/column we move by 1 (first row/column) * 10 (extent/size of
    // current) for 10 for each index increment
    // - Third row/column we would move by 10 (second row/column) *
    // (extent/size of current) 100 for 1000 for each index increment
    std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
    for (size_t i = 1; i < bounds.size(); ++i) {
      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
              bounds[i].getDefiningOp())) {
        dimensionIndexSizeOffset.push_back(builder.CreateMul(
            moduleTranslation.lookupValue(boundOp.getExtent()),
            dimensionIndexSizeOffset[i - 1]));
      }
    }

    // Now that we have calculated how much we move by per index, we must
    // multiply each lower bound offset in indexes by the size offset we
    // have calculated in the previous and accumulate the results to get
    // our final resulting offset.
    for (int i = bounds.size() - 1; i >= 0; --i) {
      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
              bounds[i].getDefiningOp())) {
        if (idx.empty())
          idx.emplace_back(builder.CreateMul(
              moduleTranslation.lookupValue(boundOp.getLowerBound()),
              dimensionIndexSizeOffset[i]));
        else
          idx.back() = builder.CreateAdd(
              idx.back(), builder.CreateMul(moduleTranslation.lookupValue(
                                                boundOp.getLowerBound()),
                                            dimensionIndexSizeOffset[i]));
      }
    }
  }

  return idx;
}

// This creates two insertions into the MapInfosTy data structure for the
// "parent" of a set of members, (usually a container e.g.
// class/structure/derived type) when subsequent members have also been
// explicitly mapped on the same map clause.
// Certain types, such as Fortran
// descriptors are mapped like this as well, however, the members are
// implicit as far as a user is concerned, but we must explicitly map them
// internally.
//
// This function also returns the memberOfFlag for this particular parent,
// which is utilised in subsequent member mappings (by modifying their map
// type with it) to indicate that a member is part of this parent and should be
// treated by the runtime as such. Important to achieve the correct mapping.
//
// This function borrows a lot from Clang's emitCombinedEntry function
// inside of CGOpenMPRuntime.cpp
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
    llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
    uint64_t mapDataIndex, bool isTargetParams) {
  // Map the first segment of our structure
  combinedInfo.Types.emplace_back(
      isTargetParams
          ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
          : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
  combinedInfo.DevicePointers.emplace_back(
      mapData.DevicePointers[mapDataIndex]);
  combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
      mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
  combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);

  // Calculate size of the parent object being mapped based on the
  // addresses at runtime, highAddr - lowAddr = size. This of course
  // doesn't factor in allocated data like pointers, hence the further
  // processing of members specified by users, or in the case of
  // Fortran pointers and allocatables, the mapping of the pointed to
  // data by the descriptor (which itself, is a structure containing
  // runtime information on the dynamically allocated data).
  auto parentClause =
      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

  llvm::Value *lowAddr, *highAddr;
  if (!parentClause.getPartialMap()) {
    // Full map of the parent: the span is the whole object,
    // [ptr, ptr + sizeof(object)).
    lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
                                        builder.getPtrTy());
    highAddr = builder.CreatePointerCast(
        builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
                                   mapData.Pointers[mapDataIndex], 1),
        builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
  } else {
    // Partial map: the span only covers the explicitly mapped members, from
    // the first mapped member up to one past the last mapped member.
    auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
    int firstMemberIdx = getMapDataMemberIdx(
        mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
    lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
                                        builder.getPtrTy());
    int lastMemberIdx = getMapDataMemberIdx(
        mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
    highAddr = builder.CreatePointerCast(
        builder.CreateGEP(mapData.BaseType[lastMemberIdx],
                          mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
        builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
  }

  llvm::Value *size = builder.CreateIntCast(
      builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
      builder.getInt64Ty(),
      /*isSigned=*/false);
  combinedInfo.Sizes.push_back(size);

  llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
      ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);

  // This creates the initial MEMBER_OF mapping that consists of
  // the parent/top level container (same as above effectively, except
  // with a fixed initial compile time size and separate maptype which
  // indicates the true map type (tofrom etc.). This parent mapping is
  // only relevant if the structure in its totality is being mapped,
  // otherwise the above suffices.
  if (!parentClause.getPartialMap()) {
    // TODO: This will need to be expanded to include the whole host of logic
    // for the map flags that Clang currently supports (e.g. it should do some
    // further case specific flag modifications). For the moment, it handles
    // what we support as expected.
    llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
    ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
    combinedInfo.Types.emplace_back(mapFlag);
    combinedInfo.DevicePointers.emplace_back(
        llvm::OpenMPIRBuilder::DeviceInfoTy::None);
    combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
        mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
    combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
    combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
    combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
  }
  return memberOfFlag;
}

// The intent is to verify if the mapped data being passed is a
// pointer -> pointee that requires special handling in certain cases,
// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
//
// There may be a better way to verify this, but unfortunately with
// opaque pointers we lose the ability to easily check if something is
// a pointer whilst maintaining access to the underlying type.
static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
  // If we have a varPtrPtr field assigned then the underlying type is a pointer
  if (mapOp.getVarPtrPtr())
    return true;

  // If the map data is declare target with a link clause, then it's represented
  // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
  // no relation to pointers.
  if (isDeclareTargetLink(mapOp.getVarPtr()))
    return true;

  return false;
}

// This function is intended to add explicit mappings of members: one
// MEMBER_OF-tagged combinedInfo entry per member of the parent at
// mapDataIndex (two entries for pointer members: the pointer itself, then
// the pointer/pointee binding).
static void processMapMembersWithParent(
    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
    llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
    uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {

  auto parentClause =
      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

  for (auto mappedMembers : parentClause.getMembers()) {
    auto memberClause =
        llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
    int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);

    assert(memberDataIdx >= 0 && "could not find mapped member of structure");

    // If we're currently mapping a pointer to a block of data, we must
    // initially map the pointer, and then attach/bind the data with a
    // subsequent map to the pointer. This segment of code generates the
    // pointer mapping, which can in certain cases be optimised out as Clang
    // currently does in its lowering. However, for the moment we do not do so,
    // in part as we currently have substantially less information on the data
    // being mapped at this stage.
    if (checkIfPointerMap(memberClause)) {
      auto mapFlag = llvm::omp::OpenMPOffloadMappingFlags(
          memberClause.getMapType().value());
      mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
      mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
      ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
      combinedInfo.Types.emplace_back(mapFlag);
      combinedInfo.DevicePointers.emplace_back(
          llvm::OpenMPIRBuilder::DeviceInfoTy::None);
      combinedInfo.Names.emplace_back(
          LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
      combinedInfo.BasePointers.emplace_back(
          mapData.BasePointers[mapDataIndex]);
      combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
      combinedInfo.Sizes.emplace_back(builder.getInt64(
          moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
    }

    // Same MemberOfFlag to indicate its link with parent and other members
    // of.
    auto mapFlag =
        llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType().value());
    mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
    ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
    if (checkIfPointerMap(memberClause))
      mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;

    combinedInfo.Types.emplace_back(mapFlag);
    combinedInfo.DevicePointers.emplace_back(
        mapData.DevicePointers[memberDataIdx]);
    combinedInfo.Names.emplace_back(
        LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
    // Pointer members are their own base; otherwise the parent acts as the
    // base pointer of the member mapping.
    uint64_t basePointerIndex =
        checkIfPointerMap(memberClause) ? memberDataIdx : mapDataIndex;
    combinedInfo.BasePointers.emplace_back(
        mapData.BasePointers[basePointerIndex]);
    combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);

    llvm::Value *size = mapData.Sizes[memberDataIdx];
    if (checkIfPointerMap(memberClause)) {
      // Map zero bytes when the member pointer is null at runtime, rather
      // than dereferencing through a null pointer.
      size = builder.CreateSelect(
          builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
          builder.getInt64(0), size);
    }

    combinedInfo.Sizes.emplace_back(size);
  }
}

// Emits a single (non parent/member) map entry for the MapInfoData entry at
// mapDataIdx into combinedInfo, deriving the final map flags from the
// clause's capture type and whether the mapped value is a pointer. When
// mapDataParentIdx >= 0 the entry's base pointer is taken from that parent
// entry instead of the entry itself.
static void
processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
                     llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
                     bool isTargetParams, int mapDataParentIdx = -1) {
  // Declare Target Mappings are excluded from being marked as
  // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
  // marked with OMP_MAP_PTR_AND_OBJ instead.
  auto mapFlag = mapData.Types[mapDataIdx];
  auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);

  bool isPtrTy = checkIfPointerMap(mapInfoOp);
  if (isPtrTy)
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;

  if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;

  if (mapInfoOp.getMapCaptureType().value() ==
          omp::VariableCaptureKind::ByCopy &&
      !isPtrTy)
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;

  // if we're provided a mapDataParentIdx, then the data being mapped is
  // part of a larger object (in a parent <-> member mapping) and in this
  // case our BasePointer should be the parent.
  if (mapDataParentIdx >= 0)
    combinedInfo.BasePointers.emplace_back(
        mapData.BasePointers[mapDataParentIdx]);
  else
    combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);

  combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
  combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
  combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
  combinedInfo.Types.emplace_back(mapFlag);
  combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
}

// Emits the map entries for a parent container and its explicitly mapped
// members, or a single individual entry for the special case of a partial
// map with exactly one member.
static void processMapWithMembersOf(
    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
    llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
    uint64_t mapDataIndex, bool isTargetParams) {
  auto parentClause =
      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

  // If we have a partial map (no parent referenced in the map clauses of the
  // directive, only members) and only a single member, we do not need to bind
  // the map of the member to the parent, we can pass the member separately.
  if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
    auto memberClause = llvm::cast<omp::MapInfoOp>(
        parentClause.getMembers()[0].getDefiningOp());
    int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
    // Note: Clang treats arrays with explicit bounds that fall into this
    // category as a parent with map case, however, it seems this isn't a
    // requirement, and processing them as an individual map is fine. So,
    // we will handle them as individual maps for the moment, as it's
    // difficult for us to check this as we always require bounds to be
    // specified currently and it's also marginally more optimal (single
    // map rather than two). The difference may come from the fact that
    // Clang maps array without bounds as pointers (which we do not
    // currently do), whereas we treat them as arrays in all cases
    // currently.
    processIndividualMap(mapData, memberDataIdx, combinedInfo, isTargetParams,
                         mapDataIndex);
    return;
  }

  llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
      mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
                           combinedInfo, mapData, mapDataIndex, isTargetParams);
  processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl,
                              combinedInfo, mapData, mapDataIndex,
                              memberOfParentFlag);
}

// This is a variation on Clang's GenerateOpenMPCapturedVars, which
// generates different operation (e.g. load/store) combinations for
// arguments to the kernel, based on map capture kinds which are then
// utilised in the combinedInfo in place of the original Map value.
static void
createAlteredByCaptureMap(MapInfoData &mapData,
                          LLVM::ModuleTranslation &moduleTranslation,
                          llvm::IRBuilderBase &builder) {
  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
    // if it's declare target, skip it, it's handled separately.
    if (!mapData.IsDeclareTarget[i]) {
      auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
      omp::VariableCaptureKind captureKind =
          mapOp.getMapCaptureType().value_or(omp::VariableCaptureKind::ByRef);
      bool isPtrTy = checkIfPointerMap(mapOp);

      // Currently handles array sectioning lowerbound case, but more
      // logic may be required in the future. Clang invokes EmitLValue,
      // which has specialised logic for special Clang types such as user
      // defines, so it is possible we will have to extend this for
      // structures or other complex types. As the general idea is that this
      // function mimics some of the logic from Clang that we require for
      // kernel argument passing from host -> device.
      switch (captureKind) {
      case omp::VariableCaptureKind::ByRef: {
        llvm::Value *newV = mapData.Pointers[i];
        std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
            moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
            mapOp.getBounds());
        if (isPtrTy)
          newV = builder.CreateLoad(builder.getPtrTy(), newV);

        if (!offsetIdx.empty())
          newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
                                           "array_offset");
        mapData.Pointers[i] = newV;
      } break;
      case omp::VariableCaptureKind::ByCopy: {
        llvm::Type *type = mapData.BaseType[i];
        llvm::Value *newV;
        if (mapData.Pointers[i]->getType()->isPointerTy())
          newV = builder.CreateLoad(type, mapData.Pointers[i]);
        else
          newV = mapData.Pointers[i];

        if (!isPtrTy) {
          // Spill the by-copy value into a temporary alloca (emitted at the
          // function's alloca insertion point) so it can be passed through
          // the kernel argument structure like the other pointers.
          auto curInsert = builder.saveIP();
          builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
          auto *memTempAlloc =
              builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
          builder.restoreIP(curInsert);

          builder.CreateStore(newV, memTempAlloc);
          newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
        }

        mapData.Pointers[i] = newV;
        mapData.BasePointers[i] = newV;
      } break;
      case omp::VariableCaptureKind::This:
      case omp::VariableCaptureKind::VLAType:
        mapData.MapClause[i]->emitOpError("Unhandled capture kind");
        break;
      }
    }
  }
}

// Generate all map related information and fill the combinedInfo.
static void genMapInfos(llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation,
                        DataLayout &dl,
                        llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
                        MapInfoData &mapData, bool isTargetParams = false) {
  // We wish to modify some of the methods in which arguments are
  // passed based on their capture type by the target region, this can
  // involve generating new loads and stores, which changes the
  // MLIR value to LLVM value mapping, however, we only wish to do this
  // locally for the current function/target and also avoid altering
  // ModuleTranslation, so we remap the base pointer or pointer stored
  // in the map infos corresponding MapInfoData, which is later accessed
  // by genMapInfos and createTarget to help generate the kernel and
  // kernel arg structure. It primarily becomes relevant in cases like
  // bycopy, or byref range'd arrays. In the default case, we simply
  // pass the pointer byref as both basePointer and pointer.
  if (!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
    createAlteredByCaptureMap(mapData, moduleTranslation, builder);

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  // We operate under the assumption that all vectors that are
  // required in MapInfoData are of equal lengths (either filled with
  // default constructed data or appropriate information) so we can
  // utilise the size from any component of MapInfoData, if we can't
  // something is missing from the initial MapInfoData construction.
  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
    // NOTE/TODO: We currently do not support arbitrary depth record
    // type mapping.
    if (mapData.IsAMember[i])
      continue;

    // NOTE(review): the dyn_cast result is dereferenced below without a null
    // check; map clauses are presumably always omp.map.info ops here -- if
    // that invariant holds, cast<> would state it more clearly.
    auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
    if (!mapInfoOp.getMembers().empty()) {
      processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
                              combinedInfo, mapData, i, isTargetParams);
      continue;
    }

    processIndividualMap(mapData, i, combinedInfo, isTargetParams);
  }
}

// Lowers omp.target_data, omp.target_enter_data, omp.target_exit_data and
// omp.target_update to OpenMPIRBuilder target-data runtime calls, collecting
// the per-op clauses (if, device, map, use_device_ptr/addr) first via the
// TypeSwitch below.
static LogicalResult
convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  llvm::Value *ifCond = nullptr;
  int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
  SmallVector<Value> mapVars;
  SmallVector<Value> useDevicePtrVars;
  SmallVector<Value> useDeviceAddrVars;
  llvm::omp::RuntimeFunction RTLFn;
  DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true,
                                             /*SeparateBeginEndCalls=*/true);

  LogicalResult result =
      llvm::TypeSwitch<Operation *, LogicalResult>(op)
          .Case([&](omp::TargetDataOp dataOp) {
            if (failed(checkImplementationStatus(*dataOp)))
              return failure();

            if (auto ifVar = dataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            // Only compile-time constant device IDs are honoured here.
            if (auto devId = dataOp.getDevice())
              if (auto constOp =
                      dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
                  deviceID = intAttr.getInt();

            mapVars = dataOp.getMapVars();
            useDevicePtrVars = dataOp.getUseDevicePtrVars();
            useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
            return success();
          })
          .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
            if (failed(checkImplementationStatus(*enterDataOp)))
              return failure();

            if (auto ifVar = enterDataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (auto devId = enterDataOp.getDevice())
              if (auto constOp =
                      dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
                  deviceID = intAttr.getInt();
            RTLFn =
                enterDataOp.getNowait()
                    ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
                    : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
            mapVars = enterDataOp.getMapVars();
            info.HasNoWait = enterDataOp.getNowait();
            return success();
          })
          .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
            if (failed(checkImplementationStatus(*exitDataOp)))
              return failure();

            if (auto ifVar = exitDataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (auto devId = exitDataOp.getDevice())
              if (auto constOp =
                      dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
                  deviceID = intAttr.getInt();

            RTLFn = exitDataOp.getNowait()
                        ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
                        : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
            mapVars = exitDataOp.getMapVars();
            info.HasNoWait = exitDataOp.getNowait();
            return success();
          })
          .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
            if (failed(checkImplementationStatus(*updateDataOp)))
              return failure();

            if (auto ifVar = updateDataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (auto devId = updateDataOp.getDevice())
              if (auto constOp =
                      dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
                  deviceID = intAttr.getInt();

            RTLFn =
                updateDataOp.getNowait()
                    ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
                    : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
            mapVars = updateDataOp.getMapVars();
            info.HasNoWait = updateDataOp.getNowait();
            return success();
          })
          .Default([&](Operation *op) {
            llvm_unreachable("unexpected operation");
            return failure();
          });

  if (failed(result))
    return failure();

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  MapInfoData mapData;
  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, DL,
                                builder, useDevicePtrVars, useDeviceAddrVars);

  // Fill up the arrays with all the mapped variables.
  llvm::OpenMPIRBuilder::MapInfosTy combinedInfo;
  auto genMapInfoCB =
      [&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & {
    builder.restoreIP(codeGenIP);
    genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
    return combinedInfo;
  };

  // Define a lambda to apply mappings between use_device_addr and
  // use_device_ptr base pointers, and their associated block arguments.
  auto mapUseDevice =
      [&moduleTranslation](
          llvm::OpenMPIRBuilder::DeviceInfoTy type,
          llvm::ArrayRef<BlockArgument> blockArgs,
          llvm::SmallVectorImpl<Value> &useDeviceVars, MapInfoData &mapInfoData,
          llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
        for (auto [arg, useDevVar] :
             llvm::zip_equal(blockArgs, useDeviceVars)) {

          auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
            return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
                                            : mapInfoOp.getVarPtr();
          };

          auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
          for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
                   mapInfoData.MapClause, mapInfoData.DevicePointers,
                   mapInfoData.BasePointers)) {
            auto mapOp = cast<omp::MapInfoOp>(mapClause);
            if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
                devicePointer != type)
              continue;

            if (llvm::Value *devPtrInfoMap =
                    mapper ? mapper(basePointer) : basePointer) {
              moduleTranslation.mapValue(arg, devPtrInfoMap);
              break;
            }
          }
        }
      };

  using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
  auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    assert(isa<omp::TargetDataOp>(op) &&
           "BodyGen requested for non TargetDataOp");
    auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
    Region &region = cast<omp::TargetDataOp>(op).getRegion();
    switch (bodyGenType) {
    case BodyGenTy::Priv:
      // Check if any device ptr/addr info is available
      if (!info.DevicePtrInfoMap.empty()) {
        builder.restoreIP(codeGenIP);

        mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
                     blockArgIface.getUseDeviceAddrBlockArgs(),
                     useDeviceAddrVars, mapData,
                     [&](llvm::Value *basePointer) -> llvm::Value * {
                       if (!info.DevicePtrInfoMap[basePointer].second)
                         return nullptr;
                       return builder.CreateLoad(
                           builder.getPtrTy(),
                           info.DevicePtrInfoMap[basePointer].second);
                     });
        mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
                     blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
                     mapData, [&](llvm::Value *basePointer) {
                       return info.DevicePtrInfoMap[basePointer].second;
                     });

        if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
                                           moduleTranslation)))
          return llvm::make_error<PreviouslyReportedError>();
      }
      break;
    case BodyGenTy::DupNoPriv:
      break;
    case BodyGenTy::NoPriv:
      // If device info is available then region has already been generated
      if (info.DevicePtrInfoMap.empty()) {
        builder.restoreIP(codeGenIP);
        // For device pass, if use_device_ptr(addr) mappings were present,
        // we need to link them here before codegen.
        if (ompBuilder->Config.IsTargetDevice.value_or(false)) {
          mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
                       blockArgIface.getUseDeviceAddrBlockArgs(),
                       useDeviceAddrVars, mapData);
          mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
                       blockArgIface.getUseDevicePtrBlockArgs(),
                       useDevicePtrVars, mapData);
        }

        if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
                                           moduleTranslation)))
          return llvm::make_error<PreviouslyReportedError>();
      }
      break;
    }
    return builder.saveIP();
  };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  // Only omp.target_data has a region and therefore a body callback; the
  // stand-alone enter/exit/update ops emit a single runtime call via RTLFn.
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
    if (isa<omp::TargetDataOp>(op))
      return ompBuilder->createTargetData(
          ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID),
          ifCond, info, genMapInfoCB, nullptr, bodyGenCB);
    return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
                                        builder.getInt64(deviceID), ifCond,
                                        info, genMapInfoCB, &RTLFn);
  }();

  if (failed(handleError(afterIP, *op)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}

/// Lowers the FlagsAttr which is applied to the module on the device
/// pass when offloading, this attribute contains OpenMP RTL globals that can
/// be passed as flags to the frontend, otherwise they are set to default
LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, 3743 LLVM::ModuleTranslation &moduleTranslation) { 3744 if (!cast<mlir::ModuleOp>(op)) 3745 return failure(); 3746 3747 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); 3748 3749 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device", 3750 attribute.getOpenmpDeviceVersion()); 3751 3752 if (attribute.getNoGpuLib()) 3753 return success(); 3754 3755 ompBuilder->createGlobalFlag( 3756 attribute.getDebugKind() /*LangOpts().OpenMPTargetDebug*/, 3757 "__omp_rtl_debug_kind"); 3758 ompBuilder->createGlobalFlag( 3759 attribute 3760 .getAssumeTeamsOversubscription() /*LangOpts().OpenMPTeamSubscription*/ 3761 , 3762 "__omp_rtl_assume_teams_oversubscription"); 3763 ompBuilder->createGlobalFlag( 3764 attribute 3765 .getAssumeThreadsOversubscription() /*LangOpts().OpenMPThreadSubscription*/ 3766 , 3767 "__omp_rtl_assume_threads_oversubscription"); 3768 ompBuilder->createGlobalFlag( 3769 attribute.getAssumeNoThreadState() /*LangOpts().OpenMPNoThreadState*/, 3770 "__omp_rtl_assume_no_thread_state"); 3771 ompBuilder->createGlobalFlag( 3772 attribute 3773 .getAssumeNoNestedParallelism() /*LangOpts().OpenMPNoNestedParallelism*/ 3774 , 3775 "__omp_rtl_assume_no_nested_parallelism"); 3776 return success(); 3777 } 3778 3779 static bool getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, 3780 omp::TargetOp targetOp, 3781 llvm::StringRef parentName = "") { 3782 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>(); 3783 3784 assert(fileLoc && "No file found from location"); 3785 StringRef fileName = fileLoc.getFilename().getValue(); 3786 3787 llvm::sys::fs::UniqueID id; 3788 if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) { 3789 targetOp.emitError("Unable to get unique ID for file"); 3790 return false; 3791 } 3792 3793 uint64_t line = fileLoc.getLine(); 3794 targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(), 3795 id.getFile(), 
line);
  return true;
}

/// Rewrite, inside the outlined device function \p func, all uses of
/// declare-target mapped globals so that they go through the reference
/// pointer generated for them, inserting a load of that pointer before
/// each use.
static void
handleDeclareTargetMapVar(MapInfoData &mapData,
                          LLVM::ModuleTranslation &moduleTranslation,
                          llvm::IRBuilderBase &builder, llvm::Function *func) {
  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
    // In the case of declare target mapped variables, the basePointer is
    // the reference pointer generated by the convertDeclareTargetAttr
    // method. Whereas the kernelValue is the original variable, so for
    // the device we must replace all uses of this original global variable
    // (stored in kernelValue) with the reference pointer (stored in
    // basePointer for declare target mapped variables), as for device the
    // data is mapped into this reference pointer and should be loaded
    // from it, the original variable is discarded. On host both exist and
    // metadata is generated (elsewhere in the convertDeclareTargetAttr)
    // function to link the two variables in the runtime and then both the
    // reference pointer and the pointer are assigned in the kernel argument
    // structure for the host.
    if (mapData.IsDeclareTarget[i]) {
      // If the original map value is a constant, then we have to make sure all
      // of its uses within the current kernel/function that we are going to
      // rewrite are converted to instructions, as we will be altering the old
      // use (OriginalValue) from a constant to an instruction, which will be
      // illegal and ICE the compiler if the user is a constant expression of
      // some kind e.g. a constant GEP.
      if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
        convertUsersOfConstantsToInstructions(constant, func, false);

      // The users iterator will get invalidated if we modify an element,
      // so we populate this vector of uses to alter each user on an
      // individual basis to emit its own load (rather than one load for
      // all).
      llvm::SmallVector<llvm::User *> userVec;
      for (llvm::User *user : mapData.OriginalValue[i]->users())
        userVec.push_back(user);

      for (llvm::User *user : userVec) {
        if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
          // Only rewrite uses that live inside the outlined kernel function.
          if (insn->getFunction() == func) {
            auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(),
                                            mapData.BasePointers[i]);
            load->moveBefore(insn->getIterator());
            user->replaceUsesOfWith(mapData.OriginalValue[i], load);
          }
        }
      }
    }
  }
}

// The createDeviceArgumentAccessor function generates
// instructions for retrieving (accessing) kernel
// arguments inside of the device kernel for use by
// the kernel. This enables different semantics such as
// the creation of temporary copies of data allowing
// semantics like read-only/no host write back kernel
// arguments.
//
// This currently implements a very light version of Clang's
// EmitParmDecl's handling of direct argument handling as well
// as a portion of the argument access generation based on
// capture types found at the end of emitOutlinedFunctionPrologue
// in Clang. The indirect path handling of EmitParmDecl's may be
// required for future work, but a direct 1-to-1 copy doesn't seem
// possible as the logic is rather scattered throughout Clang's
// lowering and perhaps we wish to deviate slightly.
//
// \param mapData - A container containing vectors of information
// corresponding to the input argument, which should have a
// corresponding entry in the MapInfoData containers
// OriginalValue's.
// \param arg - This is the generated kernel function argument that
// corresponds to the passed in input argument. We generated different
// accesses of this Argument, based on capture type and other Input
// related information.
// \param input - This is the host side value that will be passed to
// the kernel i.e. the kernel input, we rewrite all uses of this within
// the kernel (as we generate the kernel body based on the target's region
// which maintains references to the original input) to the retVal argument
// upon exit of this function inside of the OMPIRBuilder. This interlinks
// the kernel argument to future uses of it in the function providing
// appropriate "glue" instructions in between.
// \param retVal - This is the value that all uses of input inside of the
// kernel will be re-written to, the goal of this function is to generate
// an appropriate location for the kernel argument to be accessed from,
// e.g. ByRef will result in a temporary allocation location and then
// a store of the kernel argument into this allocated memory which
// will then be loaded from, ByCopy will use the allocated memory
// directly.
static llvm::IRBuilderBase::InsertPoint
createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
                             llvm::Value *input, llvm::Value *&retVal,
                             llvm::IRBuilderBase &builder,
                             llvm::OpenMPIRBuilder &ompBuilder,
                             LLVM::ModuleTranslation &moduleTranslation,
                             llvm::IRBuilderBase::InsertPoint allocaIP,
                             llvm::IRBuilderBase::InsertPoint codeGenIP) {
  builder.restoreIP(allocaIP);

  // Default to by-reference capture unless the map clause says otherwise.
  omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;

  // Find the associated MapInfoData entry for the current input.
  for (size_t i = 0; i < mapData.MapClause.size(); ++i)
    if (mapData.OriginalValue[i] == input) {
      auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
      capture =
          mapOp.getMapCaptureType().value_or(omp::VariableCaptureKind::ByRef);

      break;
    }

  unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
  unsigned int defaultAS =
      ompBuilder.M.getDataLayout().getProgramAddressSpace();

  // Create the alloca for the argument at the current (alloca) insertion
  // point.
  llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);

  // If the alloca address space differs from the program's default address
  // space (e.g. on AMDGPU), cast the pointer into the default address space.
  if (allocaAS != defaultAS && arg.getType()->isPointerTy())
    v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));

  builder.CreateStore(&arg, v);

  builder.restoreIP(codeGenIP);

  switch (capture) {
  case omp::VariableCaptureKind::ByCopy: {
    // Use the allocated storage directly.
    retVal = v;
    break;
  }
  case omp::VariableCaptureKind::ByRef: {
    // Load the stored argument back out so uses see the referenced value.
    retVal = builder.CreateAlignedLoad(
        v->getType(), v,
        ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
    break;
  }
  case omp::VariableCaptureKind::This:
  case omp::VariableCaptureKind::VLAType:
    // TODO: Consider returning error to use standard reporting for
    // unimplemented features.
    assert(false && "Currently unsupported capture kind");
    break;
  }

  return builder.saveIP();
}

/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
/// operation and populate output variables with their corresponding host value
/// (i.e. operand evaluated outside of the target region), based on their uses
/// inside of the target region.
///
/// Loop bounds and steps are only optionally populated, if output vectors are
/// provided.
3952 static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, 3953 Value &numTeamsLower, Value &numTeamsUpper, 3954 Value &threadLimit) { 3955 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp); 3956 for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(), 3957 blockArgIface.getHostEvalBlockArgs())) { 3958 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item); 3959 3960 for (Operation *user : blockArg.getUsers()) { 3961 llvm::TypeSwitch<Operation *>(user) 3962 .Case([&](omp::TeamsOp teamsOp) { 3963 if (teamsOp.getNumTeamsLower() == blockArg) 3964 numTeamsLower = hostEvalVar; 3965 else if (teamsOp.getNumTeamsUpper() == blockArg) 3966 numTeamsUpper = hostEvalVar; 3967 else if (teamsOp.getThreadLimit() == blockArg) 3968 threadLimit = hostEvalVar; 3969 else 3970 llvm_unreachable("unsupported host_eval use"); 3971 }) 3972 .Case([&](omp::ParallelOp parallelOp) { 3973 if (parallelOp.getNumThreads() == blockArg) 3974 numThreads = hostEvalVar; 3975 else 3976 llvm_unreachable("unsupported host_eval use"); 3977 }) 3978 .Case([&](omp::LoopNestOp loopOp) { 3979 // TODO: Extract bounds and step values. Currently, this cannot be 3980 // reached because translation would have been stopped earlier as a 3981 // result of `checkImplementationStatus` detecting and reporting 3982 // this situation. 3983 llvm_unreachable("unsupported host_eval use"); 3984 }) 3985 .Default([](Operation *) { 3986 llvm_unreachable("unsupported host_eval use"); 3987 }); 3988 } 3989 } 3990 } 3991 3992 /// If \p op is of the given type parameter, return it casted to that type. 3993 /// Otherwise, if its immediate parent operation (or some other higher-level 3994 /// parent, if \p immediateParent is false) is of that type, return that parent 3995 /// casted to the given type. 3996 /// 3997 /// If \p op is \c null or neither it or its parent(s) are of the specified 3998 /// type, return a \c null operation. 
template <typename OpTy>
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
  if (!op)
    return OpTy();

  if (OpTy casted = dyn_cast<OpTy>(op))
    return casted;

  // Restrict the search to the direct parent when requested; otherwise walk
  // up the whole parent chain.
  if (immediateParent)
    return dyn_cast_if_present<OpTy>(op->getParentOp());

  return op->getParentOfType<OpTy>();
}

/// If the given \p value is defined by an \c llvm.mlir.constant operation and
/// it is of an integer type, return its value.
static std::optional<int64_t> extractConstInteger(Value value) {
  if (!value)
    return std::nullopt;

  if (auto constOp =
          dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp()))
    if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
      return constAttr.getInt();

  return std::nullopt;
}

/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
/// values as stated by the corresponding clauses, if constant.
///
/// These default values must be set before the creation of the outlined LLVM
/// function for the target region, so that they can be used to initialize the
/// corresponding global `ConfigurationEnvironmentTy` structure.
static void
initTargetDefaultAttrs(omp::TargetOp targetOp,
                       llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
                       bool isTargetDevice) {
  // TODO: Handle constant 'if' clauses.
  Operation *capturedOp = targetOp.getInnermostCapturedOmpOp();

  Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
  if (!isTargetDevice) {
    extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                           threadLimit);
  } else {
    // In the target device, values for these clauses are not passed as
    // host_eval, but instead evaluated prior to entry to the region. This
    // ensures values are mapped and available inside of the target region.
    if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
      numTeamsLower = teamsOp.getNumTeamsLower();
      numTeamsUpper = teamsOp.getNumTeamsUpper();
      threadLimit = teamsOp.getThreadLimit();
    }

    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
      numThreads = parallelOp.getNumThreads();
  }

  // Handle clauses impacting the number of teams.

  int32_t minTeamsVal = 1, maxTeamsVal = -1;
  if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
    // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, match
    // clang and set min and max to the same value.
    if (numTeamsUpper) {
      if (auto val = extractConstInteger(numTeamsUpper))
        minTeamsVal = maxTeamsVal = *val;
    } else {
      // `num_teams` clause present but not constant: set but unknown.
      minTeamsVal = maxTeamsVal = 0;
    }
  } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
                                                    /*immediateParent=*/true) ||
             castOrGetParentOfType<omp::SimdOp>(capturedOp,
                                                /*immediateParent=*/true)) {
    // A directly nested `parallel` or `simd` without `teams` runs one team.
    minTeamsVal = maxTeamsVal = 1;
  } else {
    minTeamsVal = maxTeamsVal = -1;
  }

  // Handle clauses impacting the number of threads.

  // Sets `result` from a clause value: its constant value if known, 0 (set
  // but unknown) if the clause is present but non-constant, and leaves
  // `result` untouched when the clause is absent.
  auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
    if (!clauseValue)
      return;

    if (auto val = extractConstInteger(clauseValue))
      result = *val;

    // Found an applicable clause, so it's not undefined. Mark as unknown
    // because it's not constant.
    if (result < 0)
      result = 0;
  };

  // Extract 'thread_limit' clause from 'target' and 'teams' directives.
  int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
  setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
  setMaxValueFromClause(threadLimit, teamsThreadLimitVal);

  // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
  int32_t maxThreadsVal = -1;
  if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
    setMaxValueFromClause(numThreads, maxThreadsVal);
  else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
                                              /*immediateParent=*/true))
    maxThreadsVal = 1;

  // For max values, < 0 means unset, == 0 means set but unknown. Select the
  // minimum value between 'max_threads' and 'thread_limit' clauses that were
  // set.
  int32_t combinedMaxThreadsVal = targetThreadLimitVal;
  if (combinedMaxThreadsVal < 0 ||
      (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
    combinedMaxThreadsVal = teamsThreadLimitVal;

  if (combinedMaxThreadsVal < 0 ||
      (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
    combinedMaxThreadsVal = maxThreadsVal;

  // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
  attrs.MinTeams = minTeamsVal;
  attrs.MaxTeams.front() = maxTeamsVal;
  attrs.MinThreads = 1;
  attrs.MaxThreads.front() = combinedMaxThreadsVal;
}

/// Gather LLVM runtime values for all clauses evaluated in the host that are
/// passed to the kernel invocation.
///
/// This function must be called only when compiling for the host. Also, it will
/// only provide correct results if it's called after the body of \c targetOp
/// has been fully generated.
static void
initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       omp::TargetOp targetOp,
                       llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
  Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
  extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                         teamsThreadLimit);

  // TODO: Handle constant 'if' clauses.
  if (Value targetThreadLimit = targetOp.getThreadLimit())
    attrs.TargetThreadLimit.front() =
        moduleTranslation.lookupValue(targetThreadLimit);

  if (numTeamsLower)
    attrs.MinTeams = moduleTranslation.lookupValue(numTeamsLower);

  if (numTeamsUpper)
    attrs.MaxTeams.front() = moduleTranslation.lookupValue(numTeamsUpper);

  if (teamsThreadLimit)
    attrs.TeamsThreadLimit.front() =
        moduleTranslation.lookupValue(teamsThreadLimit);

  if (numThreads)
    attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);

  // TODO: Populate attrs.LoopTripCount if it is target SPMD.
}

/// Translate the given `omp.target` operation to LLVM IR by outlining its
/// region into a kernel function via the `OpenMPIRBuilder` and emitting the
/// corresponding kernel launch (host) or kernel definition (device).
static LogicalResult
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  auto targetOp = cast<omp::TargetOp>(opInst);
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  bool isTargetDevice = ompBuilder->Config.isTargetDevice();

  auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
  auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
  auto &targetRegion = targetOp.getRegion();
  // Holds the private vars that have been mapped along with the block argument
  // that corresponds to the MapInfoOp corresponding to the private var in
  // question. So, for instance:
  //
  // %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
  // omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1)
  //
  // Then, %10 has been created so that the descriptor can be used by the
  // privatizer @box.privatizer on the device side. Here we'd record {%6#0,
  // %arg0} in the mappedPrivateVars map.
  llvm::DenseMap<Value, Value> mappedPrivateVars;
  DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
  SmallVector<Value> mapVars = targetOp.getMapVars();
  ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
  llvm::Function *llvmOutlinedFn = nullptr;

  // TODO: It can also be false if a compile-time constant `false` IF clause is
  // specified.
  bool isOffloadEntry =
      isTargetDevice || !ompBuilder->Config.TargetTriples.empty();

  // For some private variables, the MapsForPrivatizedVariablesPass
  // creates MapInfoOp instances. Go through the private variables and
  // the mapped variables so that during codegeneration we are able
  // to quickly look up the corresponding map variable, if any for each
  // private variable.
  if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
    OperandRange privateVars = targetOp.getPrivateVars();
    std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
    std::optional<DenseI64ArrayAttr> privateMapIndices =
        targetOp.getPrivateMapsAttr();

    for (auto [privVarIdx, privVarSymPair] :
         llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
      auto privVar = std::get<0>(privVarSymPair);
      auto privSym = std::get<1>(privVarSymPair);

      SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
      omp::PrivateClauseOp privatizer =
          findPrivatizer(targetOp, privatizerName);

      // Only privatizers that need a mapped value are of interest here.
      if (!privatizer.needsMap())
        continue;

      mlir::Value mappedValue =
          targetOp.getMappedValueForPrivateVar(privVarIdx);
      assert(mappedValue && "Expected to find mapped value for a privatized "
                            "variable that needs mapping");

      // The MapInfoOp defining the map var isn't really needed later.
      // So, we don't store it in any datastructure. Instead, we just
      // do some sanity checks on it right now.
      auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
      [[maybe_unused]] Type varType = mapInfoOp.getVarType();

      // Check #1: Check that the type of the private variable matches
      // the type of the variable being mapped.
      if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
        assert(
            varType == privVar.getType() &&
            "Type of private var doesn't match the type of the mapped value");

      // Ok, only 1 sanity check for now.
      // Record the block argument corresponding to this mapvar.
      mappedPrivateVars.insert(
          {privVar,
           targetRegion.getArgument(argIface.getMapBlockArgsStart() +
                                    (*privateMapIndices)[privVarIdx])});
    }
  }

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // Callback invoked by the OpenMPIRBuilder to generate the body of the
  // outlined target region.
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    // Forward target-cpu and target-features function attributes from the
    // original function to the new outlined function.
    llvm::Function *llvmParentFn =
        moduleTranslation.lookupFunction(parentFn.getName());
    llvmOutlinedFn = codeGenIP.getBlock()->getParent();
    assert(llvmParentFn && llvmOutlinedFn &&
           "Both parent and outlined functions must exist at this point");

    if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    if (auto attr = llvmParentFn->getFnAttribute("target-features");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    // Map each map-entry block argument to the LLVM value of its host-side
    // variable pointer.
    for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(arg, mapOpValue);
    }

    // Do privatization after moduleTranslation has already recorded
    // mapped values.
    MutableArrayRef<BlockArgument> privateBlockArgs =
        argIface.getPrivateBlockArgs();
    SmallVector<mlir::Value> mlirPrivateVars;
    SmallVector<llvm::Value *> llvmPrivateVars;
    SmallVector<omp::PrivateClauseOp> privateDecls;
    mlirPrivateVars.reserve(privateBlockArgs.size());
    llvmPrivateVars.reserve(privateBlockArgs.size());
    collectPrivatizationDecls(targetOp, privateDecls);
    for (mlir::Value privateVar : targetOp.getPrivateVars())
      mlirPrivateVars.push_back(privateVar);

    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
        builder, moduleTranslation, privateBlockArgs, privateDecls,
        mlirPrivateVars, llvmPrivateVars, allocaIP, &mappedPrivateVars);

    if (failed(handleError(afterAllocas, *targetOp)))
      return llvm::make_error<PreviouslyReportedError>();

    // Collect the `dealloc` regions of all privatizers so cleanup can be
    // emitted after the region body.
    SmallVector<Region *> privateCleanupRegions;
    llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
                    [](omp::PrivateClauseOp privatizer) {
                      return &privatizer.getDeallocRegion();
                    });

    builder.restoreIP(codeGenIP);
    llvm::Expected<llvm::BasicBlock *> exitBlock = convertOmpOpRegions(
        targetRegion, "omp.target", builder, moduleTranslation);

    if (!exitBlock)
      return exitBlock.takeError();

    builder.SetInsertPoint(*exitBlock);
    if (!privateCleanupRegions.empty()) {
      if (failed(inlineOmpRegionCleanup(
              privateCleanupRegions, llvmPrivateVars, moduleTranslation,
              builder, "omp.targetop.private.cleanup",
              /*shouldLoadCleanupRegionArg=*/false))) {
        return llvm::createStringError(
            "failed to inline `dealloc` region of `omp.private` "
            "op in the target region");
      }
    }

    return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
  };

  StringRef parentName = parentFn.getName();

  llvm::TargetRegionEntryInfo entryInfo;

  if (!getTargetEntryUniqueInfo(entryInfo, targetOp, parentName))
    return failure();

  MapInfoData mapData;
  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                builder);

  llvm::OpenMPIRBuilder::MapInfosTy combinedInfos;
  // Callback producing the combined map information for the kernel launch.
  auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::MapInfosTy & {
    builder.restoreIP(codeGenIP);
    genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, true);
    return combinedInfos;
  };

  // Callback generating device-side access code for each kernel argument.
  auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
                           llvm::Value *&retVal, InsertPointTy allocaIP,
                           InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    // We just return the unaltered argument for the host function
    // for now, some alterations may be required in the future to
    // keep host fallback functions working identically to the device
    // version (e.g. pass ByCopy values should be treated as such on
    // host and device, currently not always the case)
    if (!isTargetDevice) {
      retVal = cast<llvm::Value>(&arg);
      return codeGenIP;
    }

    return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
                                        *ompBuilder, moduleTranslation,
                                        allocaIP, codeGenIP);
  };

  llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
  initTargetDefaultAttrs(targetOp, defaultAttrs, isTargetDevice);

  // Collect host-evaluated values needed to properly launch the kernel from the
  // host.
  if (!isTargetDevice)
    initTargetRuntimeAttrs(builder, moduleTranslation, targetOp, runtimeAttrs);

  // Pass host-evaluated values as parameters to the kernel / host fallback,
  // except if they are constants. In any case, map the MLIR block argument to
  // the corresponding LLVM values.
  llvm::SmallVector<llvm::Value *, 4> kernelInput;
  SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
  ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs();
  for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
    llvm::Value *value = moduleTranslation.lookupValue(var);
    moduleTranslation.mapValue(arg, value);

    if (!llvm::isa<llvm::Constant>(value))
      kernelInput.push_back(value);
  }

  for (size_t i = 0; i < mapVars.size(); ++i) {
    // declare target arguments are not passed to kernels as arguments
    // TODO: We currently do not handle cases where a member is explicitly
    // passed in as an argument, this will likely need to be handled in
    // the near future, rather than using IsAMember, it may be better to
    // test if the relevant BlockArg is used within the target region and
    // then use that as a basis for exclusion in the kernel inputs.
    if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
      kernelInput.push_back(mapData.OriginalValue[i]);
  }

  SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
  buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(),
                  moduleTranslation, dds);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::Value *ifCond = nullptr;
  if (Value targetIfCond = targetOp.getIfExpr())
    ifCond = moduleTranslation.lookupValue(targetIfCond);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTarget(
          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
          defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
          argAccessorCB, dds, targetOp.getNowait());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  // Remap access operations to declare target reference pointers for the
  // device, essentially generating extra loadop's as necessary
  if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
    handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
                              llvmOutlinedFn);

  return success();
}

static LogicalResult
convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
                         LLVM::ModuleTranslation &moduleTranslation) {
  // Amend omp.declare_target by deleting the IR of the outlined functions
  // created for target regions. They cannot be filtered out from MLIR earlier
  // because the omp.target operation inside must be translated to LLVM, but
  // the wrapper functions themselves must not remain at the end of the
  // process. We know that functions where omp.declare_target does not match
  // omp.is_target_device at this stage can only be wrapper functions because
  // those that aren't are removed earlier as an MLIR transformation pass.
  if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
    if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
            op->getParentOfType<ModuleOp>().getOperation())) {
      // Host compilation keeps all functions; nothing to amend.
      if (!offloadMod.getIsTargetDevice())
        return success();

      omp::DeclareTargetDeviceType declareType =
          attribute.getDeviceType().getValue();

      // On the device, host-only declare-target functions are wrappers whose
      // LLVM IR must be deleted after translation.
      if (declareType == omp::DeclareTargetDeviceType::host) {
        // NOTE(review): lookupFunction is assumed to succeed for a translated
        // function here — verify it cannot return null at this point.
        llvm::Function *llvmFunc =
            moduleTranslation.lookupFunction(funcOp.getName());
        llvmFunc->dropAllReferences();
        llvmFunc->eraseFromParent();
      }
    }
    return success();
  }

  if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
    llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
    if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
      llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
      bool isDeclaration = gOp.isDeclaration();
      bool isExternallyVisible =
          gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
      auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
      llvm::StringRef mangledName = gOp.getSymName();
      auto captureClause =
          convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
      auto deviceClause =
          convertToDeviceClauseKind(attribute.getDeviceType().getValue());
      // unused for MLIR at the moment, required in Clang for book
      // keeping
      std::vector<llvm::GlobalVariable *> generatedRefs;

      // Forward the module's target triple (if any) to the OMPIRBuilder.
      std::vector<llvm::Triple> targetTriple;
      auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
          op->getParentOfType<mlir::ModuleOp>()->getAttr(
              LLVM::LLVMDialect::getTargetTripleAttrName()));
      if (targetTripleAttr)
        targetTriple.emplace_back(targetTripleAttr.data());

      // Provides the file name and line number used to build the unique
      // offload entry; falls back to empty/0 when no FileLineColLoc exists.
      auto fileInfoCallBack = [&loc]() {
        std::string filename = "";
        std::uint64_t lineNo = 0;

        if (loc) {
          filename = loc.getFilename().str();
          lineNo = loc.getLine();
        }

        return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
                                                     lineNo);
      };

      ompBuilder->registerTargetGlobalVariable(
          captureClause, deviceClause, isDeclaration, isExternallyVisible,
          ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
          generatedRefs, /*OpenMPSimd*/ false, targetTriple,
          /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr,
          gVal->getType(), gVal);

      // On the device, a declare-target reference pointer is also required
      // unless the variable is captured with `to` and unified shared memory
      // is not in effect.
      if (ompBuilder->Config.isTargetDevice() &&
          (attribute.getCaptureClause().getValue() !=
               mlir::omp::DeclareTargetCaptureClause::to ||
           ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
        ompBuilder->getAddrOfDeclareTargetVar(
            captureClause, deviceClause, isDeclaration, isExternallyVisible,
            ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
            generatedRefs, /*OpenMPSimd*/ false, targetTriple, gVal->getType(),
            /*GlobalInitializer*/ nullptr,
            /*VariableLinkage*/ nullptr);
      }
    }
  }

  return success();
}

// Returns true if the operation is inside a TargetOp or
// is part of a declare target function.
static bool isTargetDeviceOp(Operation *op) {
  // Assumes no reverse offloading
  if (op->getParentOfType<omp::TargetOp>())
    return true;

  // Certain operations return results, and whether utilised in host or
  // target there is a chance an LLVM Dialect operation depends on it
  // by taking it in as an operand, so we must always lower these in
  // some manner or result in an ICE (whether they end up in a no-op
  // or otherwise).
  if (mlir::isa<omp::ThreadprivateOp>(op))
    return true;

  // An op inside a declare-target function (not marked host-only) is also a
  // device op.
  if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
    if (auto declareTargetIface =
            llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
                parentFn.getOperation()))
      if (declareTargetIface.isDeclareTarget() &&
          declareTargetIface.getDeclareTargetDeviceType() !=
              mlir::omp::DeclareTargetDeviceType::host)
        return true;

  return false;
}

/// Given an OpenMP MLIR operation, create the corresponding LLVM IR
/// (including OpenMP runtime calls).
static LogicalResult
convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) {

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  // Dispatch on the concrete OpenMP operation type.
  return llvm::TypeSwitch<Operation *, LogicalResult>(op)
      .Case([&](omp::BarrierOp op) -> LogicalResult {
        if (failed(checkImplementationStatus(*op)))
          return failure();

        llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
            ompBuilder->createBarrier(builder.saveIP(),
                                      llvm::omp::OMPD_barrier);
        return handleError(afterIP, *op);
      })
      .Case([&](omp::TaskyieldOp op) {
        if (failed(checkImplementationStatus(*op)))
          return failure();

        ompBuilder->createTaskyield(builder.saveIP());
        return success();
      })
      .Case([&](omp::FlushOp op) {
        if (failed(checkImplementationStatus(*op)))
          return failure();

        // No support in Openmp runtime function (__kmpc_flush) to accept
        // the argument list.
        // OpenMP standard states the following:
        // "An implementation may implement a flush with a list by ignoring
        // the list, and treating it the same as a flush without a list."
        //
        // The argument list is discarded so that, flush with a list is treated
        // same as a flush without a list.
        ompBuilder->createFlush(builder.saveIP());
        return success();
      })
      .Case([&](omp::ParallelOp op) {
        return convertOmpParallel(op, builder, moduleTranslation);
      })
      .Case([&](omp::MaskedOp) {
        return convertOmpMasked(*op, builder, moduleTranslation);
      })
      .Case([&](omp::MasterOp) {
        return convertOmpMaster(*op, builder, moduleTranslation);
      })
      .Case([&](omp::CriticalOp) {
        return convertOmpCritical(*op, builder, moduleTranslation);
      })
      .Case([&](omp::OrderedRegionOp) {
        return convertOmpOrderedRegion(*op, builder, moduleTranslation);
      })
      .Case([&](omp::OrderedOp) {
        return convertOmpOrdered(*op, builder, moduleTranslation);
      })
      .Case([&](omp::WsloopOp) {
        return convertOmpWsloop(*op, builder, moduleTranslation);
      })
      .Case([&](omp::SimdOp) {
        return convertOmpSimd(*op, builder, moduleTranslation);
      })
      .Case([&](omp::AtomicReadOp) {
        return convertOmpAtomicRead(*op, builder, moduleTranslation);
      })
      .Case([&](omp::AtomicWriteOp) {
        return convertOmpAtomicWrite(*op, builder, moduleTranslation);
      })
      .Case([&](omp::AtomicUpdateOp op) {
        return convertOmpAtomicUpdate(op, builder, moduleTranslation);
      })
      .Case([&](omp::AtomicCaptureOp op) {
        return convertOmpAtomicCapture(op, builder, moduleTranslation);
      })
      .Case([&](omp::SectionsOp) {
        return convertOmpSections(*op, builder, moduleTranslation);
      })
      .Case([&](omp::SingleOp op) {
        return convertOmpSingle(op, builder, moduleTranslation);
      })
      .Case([&](omp::TeamsOp op) {
        return convertOmpTeams(op, builder, moduleTranslation);
      })
      .Case([&](omp::TaskOp op) {
        return convertOmpTaskOp(op, builder, moduleTranslation);
      })
      .Case([&](omp::TaskgroupOp op) {
        return convertOmpTaskgroupOp(op, builder, moduleTranslation);
      })
      .Case([&](omp::TaskwaitOp op) {
        return convertOmpTaskwaitOp(op, builder, moduleTranslation);
      })
      .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareReductionOp,
            omp::CriticalDeclareOp>([](auto op) {
        // `yield` and `terminator` can be just omitted. The block structure
        // was created in the region that handles their parent operation.
        // `declare_reduction` will be used by reductions and is not
        // converted directly, skip it.
        // `critical.declare` is only used to declare names of critical
        // sections which will be used by `critical` ops and hence can be
        // ignored for lowering. The OpenMP IRBuilder will create unique
        // name for critical section names.
        return success();
      })
      .Case([&](omp::ThreadprivateOp) {
        return convertOmpThreadprivate(*op, builder, moduleTranslation);
      })
      .Case<omp::TargetDataOp, omp::TargetEnterDataOp, omp::TargetExitDataOp,
            omp::TargetUpdateOp>([&](auto op) {
        return convertOmpTargetData(op, builder, moduleTranslation);
      })
      .Case([&](omp::TargetOp) {
        return convertOmpTarget(*op, builder, moduleTranslation);
      })
      .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
          [&](auto op) {
            // No-op, should be handled by relevant owning operations e.g.
            // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp etc.
            // and then discarded
            return success();
          })
      .Default([&](Operation *inst) {
        return inst->emitError() << "not yet implemented: " << inst->getName();
      });
}

/// Convert an operation when translating for the target device; device
/// operations share the host lowering path.
static LogicalResult
convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
  return convertHostOrTargetOperation(op, builder, moduleTranslation);
}

/// Translate `omp.target` and `omp.target_data` operations found in (or being)
/// \p op, skipping the interiors of already-translated target regions.
static LogicalResult
convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  if (isa<omp::TargetOp>(op))
    return convertOmpTarget(*op, builder, moduleTranslation);
  if (isa<omp::TargetDataOp>(op))
    return convertOmpTargetData(op, builder, moduleTranslation);
  bool interrupted =
      op->walk<WalkOrder::PreOrder>([&](Operation *oper) {
          if (isa<omp::TargetOp>(oper)) {
            if (failed(convertOmpTarget(*oper, builder, moduleTranslation)))
              return WalkResult::interrupt();
            return WalkResult::skip();
          }
          if (isa<omp::TargetDataOp>(oper)) {
            if (failed(convertOmpTargetData(oper, builder, moduleTranslation)))
              return WalkResult::interrupt();
            return WalkResult::skip();
          }
          return WalkResult::advance();
        }).wasInterrupted();
  return failure(interrupted);
}

namespace {

/// Implementation of the dialect interface that converts operations belonging
/// to the OpenMP dialect to LLVM IR.
class OpenMPDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {
public:
  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

  /// Translates the given operation to LLVM IR using the provided IR builder
  /// and saving the state in `moduleTranslation`.
  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final;

  /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR,
  /// runtime calls, or operation amendments
  LogicalResult
  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                 NamedAttribute attribute,
                 LLVM::ModuleTranslation &moduleTranslation) const final;
};

} // namespace

LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
    Operation *op, ArrayRef<llvm::Instruction *> instructions,
    NamedAttribute attribute,
    LLVM::ModuleTranslation &moduleTranslation) const {
  // Dispatch on the discrete attribute name; each handler validates the
  // attribute's type before applying it to the OpenMPIRBuilder config.
  return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
             attribute.getName())
      .Case("omp.is_target_device",
            [&](Attribute attr) {
              if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setIsTargetDevice(deviceAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.is_gpu",
            [&](Attribute attr) {
              if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setIsGPU(gpuAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.host_ir_filepath",
            [&](Attribute attr) {
              if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
                llvm::OpenMPIRBuilder *ompBuilder =
                    moduleTranslation.getOpenMPBuilder();
                ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.flags",
            [&](Attribute attr) {
              if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
                return convertFlagsAttr(op, rtlAttr, moduleTranslation);
              return failure();
            })
4766 .Case("omp.version", 4767 [&](Attribute attr) { 4768 if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) { 4769 llvm::OpenMPIRBuilder *ompBuilder = 4770 moduleTranslation.getOpenMPBuilder(); 4771 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp", 4772 versionAttr.getVersion()); 4773 return success(); 4774 } 4775 return failure(); 4776 }) 4777 .Case("omp.declare_target", 4778 [&](Attribute attr) { 4779 if (auto declareTargetAttr = 4780 dyn_cast<omp::DeclareTargetAttr>(attr)) 4781 return convertDeclareTargetAttr(op, declareTargetAttr, 4782 moduleTranslation); 4783 return failure(); 4784 }) 4785 .Case("omp.requires", 4786 [&](Attribute attr) { 4787 if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) { 4788 using Requires = omp::ClauseRequires; 4789 Requires flags = requiresAttr.getValue(); 4790 llvm::OpenMPIRBuilderConfig &config = 4791 moduleTranslation.getOpenMPBuilder()->Config; 4792 config.setHasRequiresReverseOffload( 4793 bitEnumContainsAll(flags, Requires::reverse_offload)); 4794 config.setHasRequiresUnifiedAddress( 4795 bitEnumContainsAll(flags, Requires::unified_address)); 4796 config.setHasRequiresUnifiedSharedMemory( 4797 bitEnumContainsAll(flags, Requires::unified_shared_memory)); 4798 config.setHasRequiresDynamicAllocators( 4799 bitEnumContainsAll(flags, Requires::dynamic_allocators)); 4800 return success(); 4801 } 4802 return failure(); 4803 }) 4804 .Case("omp.target_triples", 4805 [&](Attribute attr) { 4806 if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) { 4807 llvm::OpenMPIRBuilderConfig &config = 4808 moduleTranslation.getOpenMPBuilder()->Config; 4809 config.TargetTriples.clear(); 4810 config.TargetTriples.reserve(triplesAttr.size()); 4811 for (Attribute tripleAttr : triplesAttr) { 4812 if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr)) 4813 config.TargetTriples.emplace_back(tripleStrAttr.getValue()); 4814 else 4815 return failure(); 4816 } 4817 return success(); 4818 } 4819 return failure(); 4820 }) 4821 
.Default([](Attribute) { 4822 // Fall through for omp attributes that do not require lowering. 4823 return success(); 4824 })(attribute.getValue()); 4825 4826 return failure(); 4827 } 4828 4829 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR 4830 /// (including OpenMP runtime calls). 4831 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation( 4832 Operation *op, llvm::IRBuilderBase &builder, 4833 LLVM::ModuleTranslation &moduleTranslation) const { 4834 4835 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); 4836 if (ompBuilder->Config.isTargetDevice()) { 4837 if (isTargetDeviceOp(op)) { 4838 return convertTargetDeviceOp(op, builder, moduleTranslation); 4839 } else { 4840 return convertTargetOpsInNest(op, builder, moduleTranslation); 4841 } 4842 } 4843 return convertHostOrTargetOperation(op, builder, moduleTranslation); 4844 } 4845 4846 void mlir::registerOpenMPDialectTranslation(DialectRegistry ®istry) { 4847 registry.insert<omp::OpenMPDialect>(); 4848 registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) { 4849 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>(); 4850 }); 4851 } 4852 4853 void mlir::registerOpenMPDialectTranslation(MLIRContext &context) { 4854 DialectRegistry registry; 4855 registerOpenMPDialectTranslation(registry); 4856 context.appendDialectRegistry(registry); 4857 } 4858