//===- LowerWorkshare.cpp - Lower the omp.workshare construct ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the lowering of omp.workshare to other omp constructs.
//
// This pass is tasked with parallelizing the loops nested in
// workshare.loop_wrapper, while both the Fortran to MLIR lowering and the
// HLFIR to FIR lowering pipelines are responsible for emitting the
// workshare.loop_wrapper ops where appropriate, according to the
// `shouldUseWorkshareLowering` function.
//
//===----------------------------------------------------------------------===//

#include <flang/Optimizer/Builder/FIRBuilder.h>
#include <flang/Optimizer/Dialect/FIROps.h>
#include <flang/Optimizer/Dialect/FIRType.h>
#include <flang/Optimizer/HLFIR/HLFIROps.h>
#include <flang/Optimizer/OpenMP/Passes.h>
#include <llvm/ADT/BreadthFirstIterator.h>
#include <llvm/ADT/STLExtras.h>
#include <llvm/ADT/SmallVectorExtras.h>
#include <llvm/ADT/iterator_range.h>
#include <llvm/Support/ErrorHandling.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/IRMapping.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/IR/Value.h>
#include <mlir/IR/Visitors.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Support/LLVM.h>

#include <variant>

namespace flangomp {
#define GEN_PASS_DEF_LOWERWORKSHARE
#include "flang/Optimizer/OpenMP/Passes.h.inc"
} // namespace flangomp

#define DEBUG_TYPE "lower-workshare"

using namespace mlir;

namespace flangomp {

// Checks for the nesting pattern below, as we need to avoid sharing the work
// of statements which are nested in some constructs such as omp.critical or
// another omp.parallel.
//
// omp.workshare { // `wsOp`
//   ...
//     omp.T { // `parent`
//       ...
//         `op`
//
template <typename T>
static bool isNestedIn(omp::WorkshareOp wsOp, Operation *op) {
  T parent = op->getParentOfType<T>();
  if (!parent)
    return false;
  return wsOp->isProperAncestor(parent);
}

bool shouldUseWorkshareLowering(Operation *op) {
  auto parentWorkshare = op->getParentOfType<omp::WorkshareOp>();

  if (!parentWorkshare)
    return false;

  if (isNestedIn<omp::CriticalOp>(parentWorkshare, op))
    return false;

  // 2.8.3 workshare Construct
  // For a parallel construct, the construct is a unit of work with respect to
  // the workshare construct. The statements contained in the parallel
  // construct are executed by a new thread team.
  if (isNestedIn<omp::ParallelOp>(parentWorkshare, op))
    return false;

  // 2.8.2 single Construct
  // Binding The binding thread set for a single region is the current team. A
  // single region binds to the innermost enclosing parallel region.
  // Description Only one of the encountering threads will execute the
  // structured block associated with the single construct.
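  //
  // Work nested in an omp.single is therefore executed by only one thread of
  // the team, so there is nothing for this pass to divide among the threads.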
  if (isNestedIn<omp::SingleOp>(parentWorkshare, op))
    return false;

  // Do not use workshare lowering until we support CFG in omp.workshare.
  if (parentWorkshare.getRegion().getBlocks().size() != 1)
    return false;

  return true;
}

} // namespace flangomp

namespace {

struct SingleRegion {
  Block::iterator begin, end;
};

/// Returns true if the work of `op` must be shared among the team, i.e. it
/// contains a workshare.loop_wrapper that binds to the workshare region we
/// are currently lowering.
static bool mustParallelizeOp(Operation *op) {
  return op
      ->walk([&](Operation *nested) {
        // We need to be careful not to pick up workshare.loop_wrapper in
        // nested omp.parallel{omp.workshare} regions, i.e. make sure that
        // `nested` binds to the workshare region we are currently handling.
        //
        // For example:
        //
        // omp.parallel {
        //   omp.workshare { // currently handling this
        //     omp.parallel {
        //       omp.workshare { // nested workshare
        //         omp.workshare.loop_wrapper {}
        //
        // Therefore, we skip if we encounter a nested omp.workshare.
        if (isa<omp::WorkshareOp>(nested))
          return WalkResult::skip();
        if (isa<omp::WorkshareLoopWrapperOp>(nested))
          return WalkResult::interrupt();
        return WalkResult::advance();
      })
      .wasInterrupted();
}

/// An operation is safe to parallelize (i.e. execute redundantly on every
/// thread) if it is a declare operation or has no memory effects.
static bool isSafeToParallelize(Operation *op) {
  return isa<hlfir::DeclareOp>(op) || isa<fir::DeclareOp>(op) ||
         isMemoryEffectFree(op);
}

/// Simple shallow copies suffice for our purposes in this pass, so we
/// implement this simpler alternative to the full-fledged `createCopyFunc` in
/// the frontend.
static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
                                         fir::FirOpBuilder builder) {
  mlir::ModuleOp module = builder.getModule();
  auto rt = cast<fir::ReferenceType>(varType);
  mlir::Type eleTy = rt.getEleTy();
  std::string copyFuncName =
      fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");

  if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
    return decl;

  // Create the copy function.
  mlir::OpBuilder::InsertionGuard guard(builder);
  mlir::OpBuilder modBuilder(module.getBodyRegion());
  llvm::SmallVector<mlir::Type> argsTy = {varType, varType};
  auto funcType = mlir::FunctionType::get(builder.getContext(), argsTy, {});
  mlir::func::FuncOp funcOp =
      modBuilder.create<mlir::func::FuncOp>(loc, copyFuncName, funcType);
  funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
  fir::factory::setInternalLinkage(funcOp);
  builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
                      {loc, loc});
  builder.setInsertionPointToStart(&funcOp.getRegion().back());

  Value loaded = builder.create<fir::LoadOp>(loc, funcOp.getArgument(1));
  builder.create<fir::StoreOp>(loc, loaded, funcOp.getArgument(0));

  builder.create<mlir::func::ReturnOp>(loc);
  return funcOp;
}

/// Returns true if `user` (or its ancestor directly nested in `parentOp`)
/// lies outside of the single region `sr`.
static bool isUserOutsideSR(Operation *user, Operation *parentOp,
                            SingleRegion sr) {
  while (user->getParentOp() != parentOp)
    user = user->getParentOp();
  return sr.begin->getBlock() != user->getBlock() ||
         !(user->isBeforeInBlock(&*sr.end) && sr.begin->isBeforeInBlock(user));
}

/// Returns true if the value `v` is used outside of the single region `sr`,
/// either directly or transitively through the results of
/// safe-to-parallelize operations contained in `sr`.
static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
  Block *srBlock = sr.begin->getBlock();
  Operation *parentOp = srBlock->getParentOp();

  for (auto &use : v.getUses()) {
    Operation *user = use.getOwner();
    if (isUserOutsideSR(user, parentOp, sr))
      return true;

    // Now we know user is inside `sr`.
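    // It may still expose `v` outside `sr` transitively through its own
    // results; the checks below handle that case.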

    // Results of nested users cannot be used outside of `sr`.
    if (user->getBlock() != srBlock)
      continue;

    // A non-safe-to-parallelize operation will be checked separately for uses
    // outside `sr`.
    if (!isSafeToParallelize(user))
      continue;

    // For safe-to-parallelize operations, we need to check if there is a
    // transitive use of `v` through them.
    for (auto res : user->getResults())
      if (isTransitivelyUsedOutside(res, sr))
        return true;
  }
  return false;
}

/// We clone pure operations in both the parallel and single blocks. This
/// function cleans them up if they end up with no uses.
static void cleanupBlock(Block *block) {
  for (Operation &op : llvm::make_early_inc_range(
           llvm::make_range(block->rbegin(), block->rend())))
    if (isOpTriviallyDead(&op))
      op.erase();
}

/// Recreates the contents of `sourceRegion` inside `targetRegion`, emitting
/// omp.single regions for runs of operations that cannot be executed by every
/// thread and omp.wsloop for workshare.loop_wrapper ops.
static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                              IRMapping &rootMapping, Location loc,
                              mlir::DominanceInfo &di) {
  OpBuilder rootBuilder(sourceRegion.getContext());
  ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
  OpBuilder copyFuncBuilder(m.getBodyRegion());
  fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);

  // Allocates a temporary, stores the value computed in the single region to
  // it, and reloads it in the parallel region, remapping `v` to the reload.
  auto mapReloadedValue =
      [&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder,
          OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
    if (auto reloaded = rootMapping.lookupOrNull(v))
      return nullptr;
    Type ty = v.getType();
    Value alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
    singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
    Value reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
    rootMapping.map(v, reloaded);
    return alloc;
  };

  // Clones the operations of `sr` into the single block (and, when safe, into
  // the parallel block as well), returning whether everything could be
  // parallelized and the list of allocas that must be copyprivate'd.
  auto moveToSingle =
      [&](SingleRegion sr, OpBuilder allocaBuilder, OpBuilder singleBuilder,
          OpBuilder parallelBuilder) -> std::pair<bool, SmallVector<Value>> {
    IRMapping singleMapping = rootMapping;
    SmallVector<Value> copyPrivate;
    bool allParallelized = true;

    for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
      if (isSafeToParallelize(&op)) {
        singleBuilder.clone(op, singleMapping);
        if (llvm::all_of(op.getOperands(), [&](Value opr) {
              // Either we have already remapped it
              bool remapped = rootMapping.contains(opr);
              // Or it is available because it dominates `sr`
              bool dominates = di.properlyDominates(opr, &*sr.begin);
              return remapped || dominates;
            })) {
          // Safe-to-parallelize operations which have all operands available
          // in the root parallel block can be executed there.
          parallelBuilder.clone(op, rootMapping);
        } else {
          // If any operand was not available, it is because the defining
          // non-safe-to-parallelize operation in `sr` had no transitive uses
          // outside `sr`, and consequently `op` should have no transitive
          // uses outside `sr` either.
          assert(llvm::all_of(op.getResults(), [&](Value v) {
            return !isTransitivelyUsedOutside(v, sr);
          }));
          allParallelized = false;
        }
      } else if (auto alloca = dyn_cast<fir::AllocaOp>(&op)) {
        auto hoisted =
            cast<fir::AllocaOp>(allocaBuilder.clone(*alloca, singleMapping));
        rootMapping.map(&*alloca, &*hoisted);
        rootMapping.map(alloca.getResult(), hoisted.getResult());
        copyPrivate.push_back(hoisted);
        allParallelized = false;
      } else {
        singleBuilder.clone(op, singleMapping);
        // Prepare reloaded values for results of operations that cannot be
        // safely parallelized and which are used after the region `sr`.
        for (auto res : op.getResults()) {
          if (isTransitivelyUsedOutside(res, sr)) {
            auto alloc = mapReloadedValue(res, allocaBuilder, singleBuilder,
                                          parallelBuilder, singleMapping);
            if (alloc)
              copyPrivate.push_back(alloc);
          }
        }
        allParallelized = false;
      }
    }
    singleBuilder.create<omp::TerminatorOp>(loc);
    return {allParallelized, copyPrivate};
  };

  for (Block &block : sourceRegion) {
    Block *targetBlock = rootBuilder.createBlock(
        &targetRegion, {}, block.getArgumentTypes(),
        llvm::map_to_vector(block.getArguments(),
                            [](BlockArgument arg) { return arg.getLoc(); }));
    rootMapping.map(&block, targetBlock);
    rootMapping.map(block.getArguments(), targetBlock->getArguments());
  }

  auto handleOneBlock = [&](Block &block) {
    Block &targetBlock = *rootMapping.lookup(&block);
    rootBuilder.setInsertionPointToStart(&targetBlock);
    Operation *terminator = block.getTerminator();
    SmallVector<std::variant<SingleRegion, Operation *>> regions;

    auto it = block.begin();
    auto getOneRegion = [&]() {
      if (&*it == terminator)
        return false;
      if (mustParallelizeOp(&*it)) {
        regions.push_back(&*it);
        it++;
        return true;
      }
      SingleRegion sr;
      sr.begin = it;
      while (&*it != terminator && !mustParallelizeOp(&*it))
        it++;
      sr.end = it;
      assert(sr.begin != sr.end);
      regions.push_back(sr);
      return true;
    };
    while (getOneRegion())
      ;

    for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
      bool isLast = i + 1 == regions.size();
      if (std::holds_alternative<SingleRegion>(opOrSingle)) {
        OpBuilder singleBuilder(sourceRegion.getContext());
        Block *singleBlock = new Block();
        singleBuilder.setInsertionPointToStart(singleBlock);

        OpBuilder allocaBuilder(sourceRegion.getContext());
        Block *allocaBlock = new Block();
        allocaBuilder.setInsertionPointToStart(allocaBlock);

        OpBuilder parallelBuilder(sourceRegion.getContext());
        Block *parallelBlock = new Block();
        parallelBuilder.setInsertionPointToStart(parallelBlock);

        auto [allParallelized, copyprivateVars] =
            moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
                         singleBuilder, parallelBuilder);
        if (allParallelized) {
          // The single region was not required as all operations were safe to
          // parallelize.
          assert(copyprivateVars.empty());
          assert(allocaBlock->empty());
          delete singleBlock;
        } else {
          omp::SingleOperands singleOperands;
          if (isLast)
            singleOperands.nowait = rootBuilder.getUnitAttr();
          singleOperands.copyprivateVars = copyprivateVars;
          cleanupBlock(singleBlock);
          for (auto var : singleOperands.copyprivateVars) {
            mlir::func::FuncOp funcOp =
                createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
            singleOperands.copyprivateSyms.push_back(
                SymbolRefAttr::get(funcOp));
          }
          omp::SingleOp singleOp =
              rootBuilder.create<omp::SingleOp>(loc, singleOperands);
          singleOp.getRegion().push_back(singleBlock);
          targetRegion.front().getOperations().splice(
              singleOp->getIterator(), allocaBlock->getOperations());
        }
        rootBuilder.getInsertionBlock()->getOperations().splice(
            rootBuilder.getInsertionPoint(), parallelBlock->getOperations());
        delete allocaBlock;
        delete parallelBlock;
      } else {
        auto op = std::get<Operation *>(opOrSingle);
        if (auto wslw = dyn_cast<omp::WorkshareLoopWrapperOp>(op)) {
          omp::WsloopOperands wsloopOperands;
          if (isLast)
            wsloopOperands.nowait = rootBuilder.getUnitAttr();
          auto wsloop =
              rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
          auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
              rootBuilder.clone(*wslw, rootMapping));
          wsloop.getRegion().takeBody(clonedWslw.getRegion());
          clonedWslw->erase();
        } else {
          assert(mustParallelizeOp(op));
          Operation *cloned = rootBuilder.cloneWithoutRegions(*op, rootMapping);
          for (auto [region, clonedRegion] :
               llvm::zip(op->getRegions(), cloned->getRegions()))
            parallelizeRegion(region, clonedRegion, rootMapping, loc, di);
        }
      }
    }

    rootBuilder.clone(*block.getTerminator(), rootMapping);
  };

  if (sourceRegion.hasOneBlock()) {
    handleOneBlock(sourceRegion.front());
  } else if (!sourceRegion.empty()) {
    auto &domTree = di.getDomTree(&sourceRegion);
    for (auto node : llvm::breadth_first(domTree.getRootNode())) {
      handleOneBlock(*node->getBlock());
    }
  }

  for (Block &targetBlock : targetRegion)
    cleanupBlock(&targetBlock);
}

/// Lowers workshare to a sequence of single-thread regions and parallel loops.
///
/// For example:
///
/// omp.workshare {
///   %a = fir.allocmem
///   omp.workshare.loop_wrapper {}
///   fir.call Assign %b %a
///   fir.freemem %a
/// }
///
/// becomes
///
/// %tmp = fir.alloca
/// omp.single copyprivate(%tmp) {
///   %a = fir.allocmem
///   fir.store %a %tmp
/// }
/// %a_reloaded = fir.load %tmp
/// omp.workshare.loop_wrapper {}
/// omp.single {
///   fir.call Assign %b %a_reloaded
///   fir.freemem %a_reloaded
/// }
///
/// Note that we allocate temporary memory for values produced in omp.single
/// regions which need to be accessed by all threads, and broadcast them using
/// the single's copyprivate clause.
LogicalResult lowerWorkshare(mlir::omp::WorkshareOp wsOp, DominanceInfo &di) {
  Location loc = wsOp->getLoc();
  IRMapping rootMapping;

  OpBuilder rootBuilder(wsOp);

  // FIXME Currently, we only support workshare constructs with structured
  // control flow. The transformation itself supports CFG; however, once we
  // transform the MLIR region in the omp.workshare, we need to inline that
  // region into the parent block. We have no guarantees at this point of the
  // pipeline that the parent op supports CFG (e.g. fir.if), thus this is not
  // generally possible. The alternative is to put the lowered region in an
  // operation akin to scf.execute_region, which will get lowered at the same
  // time when fir ops get lowered to CFG. However, SCF is not registered in
  // flang so we cannot use it. Remove this requirement once we have
  // scf.execute_region or an alternative operation available.
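  //
  // When the region consists of a single block we can lower it in place
  // below; otherwise the whole workshare body is serialized into an
  // omp.single as a fallback.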
  if (wsOp.getRegion().getBlocks().size() == 1) {
    // This operation is just a placeholder which will be erased later. We need
    // it because our `parallelizeRegion` function works on regions and not
    // blocks.
    omp::WorkshareOp newOp =
        rootBuilder.create<omp::WorkshareOp>(loc, omp::WorkshareOperands());
    if (!wsOp.getNowait())
      rootBuilder.create<omp::BarrierOp>(loc);

    parallelizeRegion(wsOp.getRegion(), newOp.getRegion(), rootMapping, loc,
                      di);

    // Inline the contents of the placeholder workshare op into its parent
    // block.
    Block *theBlock = &newOp.getRegion().front();
    Operation *term = theBlock->getTerminator();
    Block *parentBlock = wsOp->getBlock();
    parentBlock->getOperations().splice(newOp->getIterator(),
                                        theBlock->getOperations());
    assert(term->getNumOperands() == 0);
    term->erase();
    newOp->erase();
    wsOp->erase();
  } else {
    // Otherwise just change the operation to an omp.single.

    wsOp->emitWarning(
        "omp workshare with unstructured control flow is currently "
        "unsupported and will be serialized.");

    // `shouldUseWorkshareLowering` should have guaranteed that there are no
    // omp.workshare.loop_wrapper ops that bind to this omp.workshare.
    assert(!wsOp->walk([&](Operation *op) {
                  // Nested omp.workshare ops can have their own
                  // omp.workshare.loop_wrapper ops.
                  if (isa<omp::WorkshareOp>(op))
                    return WalkResult::skip();
                  if (isa<omp::WorkshareLoopWrapperOp>(op))
                    return WalkResult::interrupt();
                  return WalkResult::advance();
                }).wasInterrupted());

    omp::SingleOperands operands;
    operands.nowait = wsOp.getNowaitAttr();
    omp::SingleOp newOp = rootBuilder.create<omp::SingleOp>(loc, operands);

    newOp.getRegion().getBlocks().splice(newOp.getRegion().getBlocks().begin(),
                                         wsOp.getRegion().getBlocks());
    wsOp->erase();
  }
  return success();
}

class LowerWorksharePass
    : public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> {
public:
  void runOnOperation() override {
    mlir::DominanceInfo &di = getAnalysis<mlir::DominanceInfo>();
    getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) {
      if (failed(lowerWorkshare(wsOp, di)))
        signalPassFailure();
    });
  }
};
} // namespace