xref: /llvm-project/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp (revision 261a4026e8a367eda229879838709b92abaf445c)
//===- LowerWorkshare.cpp - lower omp.workshare to other OpenMP ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the lowering of omp.workshare to other omp constructs.
//
// This pass is tasked with parallelizing the loops nested in
// workshare.loop_wrapper, while the Fortran-to-MLIR lowering and the
// HLFIR-to-FIR lowering pipelines are responsible for emitting the
// workshare.loop_wrapper ops where appropriate, according to the
// `shouldUseWorkshareLowering` function.
//
//===----------------------------------------------------------------------===//

#include <flang/Optimizer/Builder/FIRBuilder.h>
#include <flang/Optimizer/Dialect/FIROps.h>
#include <flang/Optimizer/Dialect/FIRType.h>
#include <flang/Optimizer/HLFIR/HLFIROps.h>
#include <flang/Optimizer/OpenMP/Passes.h>
#include <llvm/ADT/BreadthFirstIterator.h>
#include <llvm/ADT/STLExtras.h>
#include <llvm/ADT/SmallVectorExtras.h>
#include <llvm/ADT/iterator_range.h>
#include <llvm/Support/ErrorHandling.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/IRMapping.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/IR/Value.h>
#include <mlir/IR/Visitors.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Support/LLVM.h>

#include <variant>

namespace flangomp {
#define GEN_PASS_DEF_LOWERWORKSHARE
#include "flang/Optimizer/OpenMP/Passes.h.inc"
} // namespace flangomp

#define DEBUG_TYPE "lower-workshare"

using namespace mlir;

namespace flangomp {

// Checks for the nesting pattern below, as we need to avoid sharing the work
// of statements that are nested in constructs such as omp.critical or another
// omp.parallel.
//
// omp.workshare { // `wsOp`
//   ...
//     omp.T { // `parent`
//       ...
//         `op`
//
template <typename T>
static bool isNestedIn(omp::WorkshareOp wsOp, Operation *op) {
  T parent = op->getParentOfType<T>();
  if (!parent)
    return false;
  return wsOp->isProperAncestor(parent);
}

bool shouldUseWorkshareLowering(Operation *op) {
  auto parentWorkshare = op->getParentOfType<omp::WorkshareOp>();

  if (!parentWorkshare)
    return false;

  if (isNestedIn<omp::CriticalOp>(parentWorkshare, op))
    return false;

  // 2.8.3  workshare Construct
  // For a parallel construct, the construct is a unit of work with respect to
  // the workshare construct. The statements contained in the parallel construct
  // are executed by a new thread team.
  if (isNestedIn<omp::ParallelOp>(parentWorkshare, op))
    return false;

  // 2.8.2  single Construct
  // Binding The binding thread set for a single region is the current team. A
  // single region binds to the innermost enclosing parallel region.
  // Description Only one of the encountering threads will execute the
  // structured block associated with the single construct.
  if (isNestedIn<omp::SingleOp>(parentWorkshare, op))
    return false;

  // Do not use workshare lowering until we support CFG in omp.workshare
  if (parentWorkshare.getRegion().getBlocks().size() != 1)
    return false;

  return true;
}

} // namespace flangomp

namespace {

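/// A contiguous range of operations [begin, end) within a block which will be
/// executed by a single thread (i.e. wrapped in an omp.single) rather than
/// work-shared across the team.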
struct SingleRegion {
  Block::iterator begin, end;
};

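/// Returns true if `op` must be executed by the whole team (rather than inside
/// an omp.single), i.e. if it contains an omp.workshare.loop_wrapper that
/// binds to the omp.workshare currently being lowered.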
static bool mustParallelizeOp(Operation *op) {
  return op
      ->walk([&](Operation *nested) {
        // We need to be careful not to pick up workshare.loop_wrapper in nested
        // omp.parallel{omp.workshare} regions, i.e. make sure that `nested`
        // binds to the workshare region we are currently handling.
        //
        // For example:
        //
        // omp.parallel {
        //   omp.workshare { // currently handling this
        //     omp.parallel {
        //       omp.workshare { // nested workshare
        //         omp.workshare.loop_wrapper {}
        //
        // Therefore, we skip if we encounter a nested omp.workshare.
        if (isa<omp::WorkshareOp>(nested))
          return WalkResult::skip();
        if (isa<omp::WorkshareLoopWrapperOp>(nested))
          return WalkResult::interrupt();
        return WalkResult::advance();
      })
      .wasInterrupted();
}

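/// An operation is safe to parallelize if it can be executed redundantly by
/// every thread in the team without changing the program's semantics, e.g.
/// declarations and memory-effect-free operations.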
static bool isSafeToParallelize(Operation *op) {
  return isa<hlfir::DeclareOp>(op) || isa<fir::DeclareOp>(op) ||
         isMemoryEffectFree(op);
}

/// Simple shallow copies suffice for our purposes in this pass, so we implement
/// this simpler alternative to the full-fledged `createCopyFunc` in the
/// frontend.
static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
                                         fir::FirOpBuilder builder) {
  mlir::ModuleOp module = builder.getModule();
  auto rt = cast<fir::ReferenceType>(varType);
  mlir::Type eleTy = rt.getEleTy();
  std::string copyFuncName =
      fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");

  if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
    return decl;

  // create function
  mlir::OpBuilder::InsertionGuard guard(builder);
  mlir::OpBuilder modBuilder(module.getBodyRegion());
  llvm::SmallVector<mlir::Type> argsTy = {varType, varType};
  auto funcType = mlir::FunctionType::get(builder.getContext(), argsTy, {});
  mlir::func::FuncOp funcOp =
      modBuilder.create<mlir::func::FuncOp>(loc, copyFuncName, funcType);
  funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
  fir::factory::setInternalLinkage(funcOp);
  builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
                      {loc, loc});
  builder.setInsertionPointToStart(&funcOp.getRegion().back());

  Value loaded = builder.create<fir::LoadOp>(loc, funcOp.getArgument(1));
  builder.create<fir::StoreOp>(loc, loaded, funcOp.getArgument(0));

  builder.create<mlir::func::ReturnOp>(loc);
  return funcOp;
}

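/// Returns true if `user`, or the ancestor of `user` that is a direct child of
/// `parentOp`, lies outside of the single region `sr`.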
static bool isUserOutsideSR(Operation *user, Operation *parentOp,
                            SingleRegion sr) {
  while (user->getParentOp() != parentOp)
    user = user->getParentOp();
  return sr.begin->getBlock() != user->getBlock() ||
         !(user->isBeforeInBlock(&*sr.end) && sr.begin->isBeforeInBlock(user));
}

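/// Returns true if the value `v` is used outside of the single region `sr`,
/// either directly or transitively through results of safe-to-parallelize
/// operations contained in `sr`.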
static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
  Block *srBlock = sr.begin->getBlock();
  Operation *parentOp = srBlock->getParentOp();

  for (auto &use : v.getUses()) {
    Operation *user = use.getOwner();
    if (isUserOutsideSR(user, parentOp, sr))
      return true;

    // Now we know user is inside `sr`.

    // Results of nested users cannot be used outside of `sr`.
    if (user->getBlock() != srBlock)
      continue;

    // A non-safe to parallelize operation will be checked for uses outside
    // separately.
    if (!isSafeToParallelize(user))
      continue;

    // For safe to parallelize operations, we need to check if there is a
    // transitive use of `v` through them.
    for (auto res : user->getResults())
      if (isTransitivelyUsedOutside(res, sr))
        return true;
  }
  return false;
}


/// We clone pure operations in both the parallel and single blocks. This
/// function cleans them up if they end up with no uses.
static void cleanupBlock(Block *block) {
  for (Operation &op : llvm::make_early_inc_range(
           llvm::make_range(block->rbegin(), block->rend())))
    if (isOpTriviallyDead(&op))
      op.erase();
}

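/// Clones the contents of `sourceRegion` into `targetRegion` such that the
/// result can be executed by a whole thread team: runs of operations that must
/// execute once are wrapped in omp.single (with copyprivate used to publish
/// their results), while omp.workshare.loop_wrapper ops are rewritten to
/// omp.wsloop so that their iterations are shared among the threads.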
static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                              IRMapping &rootMapping, Location loc,
                              mlir::DominanceInfo &di) {
  OpBuilder rootBuilder(sourceRegion.getContext());
  ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
  OpBuilder copyFuncBuilder(m.getBodyRegion());
  fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);

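  // Prepares a value computed inside an omp.single so that it is visible to
  // all threads: allocates a temporary, stores the single-thread value into
  // it, and reloads it in the parallel block. Returns the temporary (to be
  // added to copyprivate) or nullptr if `v` has already been remapped.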
  auto mapReloadedValue =
      [&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder,
          OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
    if (auto reloaded = rootMapping.lookupOrNull(v))
      return nullptr;
    Type ty = v.getType();
    Value alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
    singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
    Value reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
    rootMapping.map(v, reloaded);
    return alloc;
  };

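  // Clones the operations of the single region `sr` into the single block;
  // operations that are safe to parallelize and whose operands are available
  // are additionally cloned into the parallel block. Returns whether
  // everything was parallelizable (making the omp.single unnecessary) together
  // with the allocas that must be broadcast via copyprivate.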
  auto moveToSingle =
      [&](SingleRegion sr, OpBuilder allocaBuilder, OpBuilder singleBuilder,
          OpBuilder parallelBuilder) -> std::pair<bool, SmallVector<Value>> {
    IRMapping singleMapping = rootMapping;
    SmallVector<Value> copyPrivate;
    bool allParallelized = true;

    for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
      if (isSafeToParallelize(&op)) {
        singleBuilder.clone(op, singleMapping);
        if (llvm::all_of(op.getOperands(), [&](Value opr) {
              // Either we have already remapped it
              bool remapped = rootMapping.contains(opr);
              // Or it is available because it dominates `sr`
              bool dominates = di.properlyDominates(opr, &*sr.begin);
              return remapped || dominates;
            })) {
          // Safe-to-parallelize operations whose operands are all available in
          // the root parallel block can be executed there.
          parallelBuilder.clone(op, rootMapping);
        } else {
          // If any operand was not available, it means that there was no
          // transitive use outside `sr` of the non-safe-to-parallelize
          // operation it originates from (otherwise it would have been
          // reloaded). This means that there should be no transitive uses
          // outside `sr` of `op` either.
          assert(llvm::all_of(op.getResults(), [&](Value v) {
            return !isTransitivelyUsedOutside(v, sr);
          }));
          allParallelized = false;
        }
      } else if (auto alloca = dyn_cast<fir::AllocaOp>(&op)) {
        auto hoisted =
            cast<fir::AllocaOp>(allocaBuilder.clone(*alloca, singleMapping));
        rootMapping.map(&*alloca, &*hoisted);
        rootMapping.map(alloca.getResult(), hoisted.getResult());
        copyPrivate.push_back(hoisted);
        allParallelized = false;
      } else {
        singleBuilder.clone(op, singleMapping);
        // Prepare reloaded values for results of operations that cannot be
        // safely parallelized and which are used after the region `sr`.
        for (auto res : op.getResults()) {
          if (isTransitivelyUsedOutside(res, sr)) {
            auto alloc = mapReloadedValue(res, allocaBuilder, singleBuilder,
                                          parallelBuilder, singleMapping);
            if (alloc)
              copyPrivate.push_back(alloc);
          }
        }
        allParallelized = false;
      }
    }
    singleBuilder.create<omp::TerminatorOp>(loc);
    return {allParallelized, copyPrivate};
  };

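  // Create the target blocks up front and map the source blocks and their
  // arguments to them, so that cloned terminators can be remapped to the
  // right successor blocks.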
  for (Block &block : sourceRegion) {
    Block *targetBlock = rootBuilder.createBlock(
        &targetRegion, {}, block.getArgumentTypes(),
        llvm::map_to_vector(block.getArguments(),
                            [](BlockArgument arg) { return arg.getLoc(); }));
    rootMapping.map(&block, targetBlock);
    rootMapping.map(block.getArguments(), targetBlock->getArguments());
  }

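  // Splits the operations of `block` into an alternating sequence of single
  // regions and operations that must be parallelized, then emits each of them
  // into the corresponding target block.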
  auto handleOneBlock = [&](Block &block) {
    Block &targetBlock = *rootMapping.lookup(&block);
    rootBuilder.setInsertionPointToStart(&targetBlock);
    Operation *terminator = block.getTerminator();
    SmallVector<std::variant<SingleRegion, Operation *>> regions;

    auto it = block.begin();
    auto getOneRegion = [&]() {
      if (&*it == terminator)
        return false;
      if (mustParallelizeOp(&*it)) {
        regions.push_back(&*it);
        it++;
        return true;
      }
      SingleRegion sr;
      sr.begin = it;
      while (&*it != terminator && !mustParallelizeOp(&*it))
        it++;
      sr.end = it;
      assert(sr.begin != sr.end);
      regions.push_back(sr);
      return true;
    };
    while (getOneRegion())
      ;

    for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
      bool isLast = i + 1 == regions.size();
      if (std::holds_alternative<SingleRegion>(opOrSingle)) {
        OpBuilder singleBuilder(sourceRegion.getContext());
        Block *singleBlock = new Block();
        singleBuilder.setInsertionPointToStart(singleBlock);

        OpBuilder allocaBuilder(sourceRegion.getContext());
        Block *allocaBlock = new Block();
        allocaBuilder.setInsertionPointToStart(allocaBlock);

        OpBuilder parallelBuilder(sourceRegion.getContext());
        Block *parallelBlock = new Block();
        parallelBuilder.setInsertionPointToStart(parallelBlock);

        auto [allParallelized, copyprivateVars] =
            moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
                         singleBuilder, parallelBuilder);
        if (allParallelized) {
          // The single region was not required as all operations were safe to
          // parallelize
          assert(copyprivateVars.empty());
          assert(allocaBlock->empty());
          delete singleBlock;
        } else {
          omp::SingleOperands singleOperands;
          if (isLast)
            singleOperands.nowait = rootBuilder.getUnitAttr();
          singleOperands.copyprivateVars = copyprivateVars;
          cleanupBlock(singleBlock);
          for (auto var : singleOperands.copyprivateVars) {
            mlir::func::FuncOp funcOp =
                createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
            singleOperands.copyprivateSyms.push_back(
                SymbolRefAttr::get(funcOp));
          }
          omp::SingleOp singleOp =
              rootBuilder.create<omp::SingleOp>(loc, singleOperands);
          singleOp.getRegion().push_back(singleBlock);
          targetRegion.front().getOperations().splice(
              singleOp->getIterator(), allocaBlock->getOperations());
        }
        rootBuilder.getInsertionBlock()->getOperations().splice(
            rootBuilder.getInsertionPoint(), parallelBlock->getOperations());
        delete allocaBlock;
        delete parallelBlock;
      } else {
        auto op = std::get<Operation *>(opOrSingle);
        if (auto wslw = dyn_cast<omp::WorkshareLoopWrapperOp>(op)) {
          omp::WsloopOperands wsloopOperands;
          if (isLast)
            wsloopOperands.nowait = rootBuilder.getUnitAttr();
          auto wsloop =
              rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
          auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
              rootBuilder.clone(*wslw, rootMapping));
          wsloop.getRegion().takeBody(clonedWslw.getRegion());
          clonedWslw->erase();
        } else {
          assert(mustParallelizeOp(op));
          Operation *cloned = rootBuilder.cloneWithoutRegions(*op, rootMapping);
          for (auto [region, clonedRegion] :
               llvm::zip(op->getRegions(), cloned->getRegions()))
            parallelizeRegion(region, clonedRegion, rootMapping, loc, di);
        }
      }
    }

    rootBuilder.clone(*block.getTerminator(), rootMapping);
  };

  if (sourceRegion.hasOneBlock()) {
    handleOneBlock(sourceRegion.front());
  } else if (!sourceRegion.empty()) {
    auto &domTree = di.getDomTree(&sourceRegion);
    for (auto node : llvm::breadth_first(domTree.getRootNode())) {
      handleOneBlock(*node->getBlock());
    }
  }

  for (Block &targetBlock : targetRegion)
    cleanupBlock(&targetBlock);
}

/// Lowers workshare to a sequence of single-thread regions and parallel loops.
///
/// For example:
///
/// omp.workshare {
///   %a = fir.allocmem
///   omp.workshare.loop_wrapper {}
///   fir.call Assign %b %a
///   fir.freemem %a
/// }
///
/// becomes
///
/// %tmp = fir.alloca
/// omp.single copyprivate(%tmp) {
///   %a = fir.allocmem
///   fir.store %a %tmp
/// }
/// %a_reloaded = fir.load %tmp
/// omp.workshare.loop_wrapper {}
/// omp.single {
///   fir.call Assign %b %a_reloaded
///   fir.freemem %a_reloaded
/// }
///
/// Note that we allocate temporary memory for values defined in omp.single
/// regions which need to be accessed by all threads, and broadcast them using
/// the single's copyprivate clause.
LogicalResult lowerWorkshare(mlir::omp::WorkshareOp wsOp, DominanceInfo &di) {
  Location loc = wsOp->getLoc();
  IRMapping rootMapping;

  OpBuilder rootBuilder(wsOp);

  // FIXME Currently, we only support workshare constructs with structured
  // control flow. The transformation itself supports CFG; however, once we
  // transform the MLIR region in the omp.workshare, we need to inline that
  // region into the parent block. We have no guarantee at this point of the
  // pipeline that the parent op supports CFG (e.g. fir.if), thus this is not
  // generally possible. The alternative is to put the lowered region in an
  // operation akin to scf.execute_region, which would get lowered at the same
  // time fir ops get lowered to CFG. However, SCF is not registered in flang,
  // so we cannot use it. Remove this requirement once we have
  // scf.execute_region or an alternative operation available.
  if (wsOp.getRegion().getBlocks().size() == 1) {
    // This operation is just a placeholder which will be erased later. We need
    // it because our `parallelizeRegion` function works on regions and not
    // blocks.
    omp::WorkshareOp newOp =
        rootBuilder.create<omp::WorkshareOp>(loc, omp::WorkshareOperands());
    if (!wsOp.getNowait())
      rootBuilder.create<omp::BarrierOp>(loc);

    parallelizeRegion(wsOp.getRegion(), newOp.getRegion(), rootMapping, loc,
                      di);

    // Inline the contents of the placeholder workshare op into its parent
    // block.
    Block *theBlock = &newOp.getRegion().front();
    Operation *term = theBlock->getTerminator();
    Block *parentBlock = wsOp->getBlock();
    parentBlock->getOperations().splice(newOp->getIterator(),
                                        theBlock->getOperations());
    assert(term->getNumOperands() == 0);
    term->erase();
    newOp->erase();
    wsOp->erase();
  } else {
    // Otherwise just change the operation to an omp.single.

    wsOp->emitWarning(
        "omp workshare with unstructured control flow is currently "
        "unsupported and will be serialized.");

    // `shouldUseWorkshareLowering` should have guaranteed that there are no
    // omp.workshare.loop_wrapper ops that bind to this omp.workshare.
    assert(!wsOp->walk([&](Operation *op) {
                  // Nested omp.workshare ops can have their own
                  // omp.workshare.loop_wrapper ops.
                  if (isa<omp::WorkshareOp>(op))
                    return WalkResult::skip();
                  if (isa<omp::WorkshareLoopWrapperOp>(op))
                    return WalkResult::interrupt();
                  return WalkResult::advance();
                })
                .wasInterrupted());

    omp::SingleOperands operands;
    operands.nowait = wsOp.getNowaitAttr();
    omp::SingleOp newOp = rootBuilder.create<omp::SingleOp>(loc, operands);

    newOp.getRegion().getBlocks().splice(newOp.getRegion().getBlocks().begin(),
                                         wsOp.getRegion().getBlocks());
    wsOp->erase();
  }
  return success();
}

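/// Pass that lowers every omp.workshare operation nested under the root
/// operation using `lowerWorkshare`.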
class LowerWorksharePass
    : public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> {
public:
  void runOnOperation() override {
    mlir::DominanceInfo &di = getAnalysis<mlir::DominanceInfo>();
    getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) {
      if (failed(lowerWorkshare(wsOp, di)))
        signalPassFailure();
    });
  }
};
} // namespace