xref: /llvm-project/mlir/include/mlir/IR/Threading.h (revision 6594f428de91e333c1cbea4f55e79b18d31024c4)
1 //===- Threading.h - MLIR Threading Utilities -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines various utilities for multithreaded processing in MLIR.
10 // These utilities automatically handle many of the necessary threading
11 // conditions, such as properly ordering diagnostics, observing if threading is
12 // disabled, etc. These utilities should be used over other threading utilities
13 // whenever feasible.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef MLIR_IR_THREADING_H
18 #define MLIR_IR_THREADING_H
19 
20 #include "mlir/IR/Diagnostics.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/Support/ThreadPool.h"
23 #include <atomic>
24 
25 namespace mlir {
26 
27 /// Invoke the given function on the elements between [begin, end)
28 /// asynchronously. If the given function returns a failure when processing any
29 /// of the elements, execution is stopped and a failure is returned from this
30 /// function. This means that in the case of failure, not all elements of the
31 /// range will be processed. Diagnostics emitted during processing are ordered
32 /// relative to the element's position within [begin, end). If the provided
33 /// context does not have multi-threading enabled, this function always
34 /// processes elements sequentially.
35 template <typename IteratorT, typename FuncT>
failableParallelForEach(MLIRContext * context,IteratorT begin,IteratorT end,FuncT && func)36 LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
37                                       IteratorT end, FuncT &&func) {
38   unsigned numElements = static_cast<unsigned>(std::distance(begin, end));
39   if (numElements == 0)
40     return success();
41 
42   // If multithreading is disabled or there is a small number of elements,
43   // process the elements directly on this thread.
44   if (!context->isMultithreadingEnabled() || numElements <= 1) {
45     for (; begin != end; ++begin)
46       if (failed(func(*begin)))
47         return failure();
48     return success();
49   }
50 
51   // Build a wrapper processing function that properly initializes a parallel
52   // diagnostic handler.
53   ParallelDiagnosticHandler handler(context);
54   std::atomic<unsigned> curIndex(0);
55   std::atomic<bool> processingFailed(false);
56   auto processFn = [&] {
57     while (!processingFailed) {
58       unsigned index = curIndex++;
59       if (index >= numElements)
60         break;
61       handler.setOrderIDForThread(index);
62       if (failed(func(*std::next(begin, index))))
63         processingFailed = true;
64       handler.eraseOrderIDForThread();
65     }
66   };
67 
68   // Otherwise, process the elements in parallel.
69   llvm::ThreadPoolInterface &threadPool = context->getThreadPool();
70   llvm::ThreadPoolTaskGroup tasksGroup(threadPool);
71   size_t numActions = std::min(numElements, threadPool.getMaxConcurrency());
72   for (unsigned i = 0; i < numActions; ++i)
73     tasksGroup.async(processFn);
74   // If the current thread is a worker thread from the pool, then waiting for
75   // the task group allows the current thread to also participate in processing
76   // tasks from the group, which avoid any deadlock/starvation.
77   tasksGroup.wait();
78   return failure(processingFailed);
79 }
80 
81 /// Invoke the given function on the elements in the provided range
82 /// asynchronously. If the given function returns a failure when processing any
83 /// of the elements, execution is stopped and a failure is returned from this
84 /// function. This means that in the case of failure, not all elements of the
85 /// range will be processed. Diagnostics emitted during processing are ordered
86 /// relative to the element's position within the range. If the provided context
87 /// does not have multi-threading enabled, this function always processes
88 /// elements sequentially.
89 template <typename RangeT, typename FuncT>
failableParallelForEach(MLIRContext * context,RangeT && range,FuncT && func)90 LogicalResult failableParallelForEach(MLIRContext *context, RangeT &&range,
91                                       FuncT &&func) {
92   return failableParallelForEach(context, std::begin(range), std::end(range),
93                                  std::forward<FuncT>(func));
94 }
95 
96 /// Invoke the given function on the elements between [begin, end)
97 /// asynchronously. If the given function returns a failure when processing any
98 /// of the elements, execution is stopped and a failure is returned from this
99 /// function. This means that in the case of failure, not all elements of the
100 /// range will be processed. Diagnostics emitted during processing are ordered
101 /// relative to the element's position within [begin, end). If the provided
102 /// context does not have multi-threading enabled, this function always
103 /// processes elements sequentially.
104 template <typename FuncT>
failableParallelForEachN(MLIRContext * context,size_t begin,size_t end,FuncT && func)105 LogicalResult failableParallelForEachN(MLIRContext *context, size_t begin,
106                                        size_t end, FuncT &&func) {
107   return failableParallelForEach(context, llvm::seq(begin, end),
108                                  std::forward<FuncT>(func));
109 }
110 
111 /// Invoke the given function on the elements between [begin, end)
112 /// asynchronously. Diagnostics emitted during processing are ordered relative
113 /// to the element's position within [begin, end). If the provided context does
114 /// not have multi-threading enabled, this function always processes elements
115 /// sequentially.
116 template <typename IteratorT, typename FuncT>
parallelForEach(MLIRContext * context,IteratorT begin,IteratorT end,FuncT && func)117 void parallelForEach(MLIRContext *context, IteratorT begin, IteratorT end,
118                      FuncT &&func) {
119   (void)failableParallelForEach(context, begin, end, [&](auto &&value) {
120     return func(std::forward<decltype(value)>(value)), success();
121   });
122 }
123 
124 /// Invoke the given function on the elements in the provided range
125 /// asynchronously. Diagnostics emitted during processing are ordered relative
126 /// to the element's position within the range. If the provided context does not
127 /// have multi-threading enabled, this function always processes elements
128 /// sequentially.
129 template <typename RangeT, typename FuncT>
parallelForEach(MLIRContext * context,RangeT && range,FuncT && func)130 void parallelForEach(MLIRContext *context, RangeT &&range, FuncT &&func) {
131   parallelForEach(context, std::begin(range), std::end(range),
132                   std::forward<FuncT>(func));
133 }
134 
135 /// Invoke the given function on the elements between [begin, end)
136 /// asynchronously. Diagnostics emitted during processing are ordered relative
137 /// to the element's position within [begin, end). If the provided context does
138 /// not have multi-threading enabled, this function always processes elements
139 /// sequentially.
140 template <typename FuncT>
parallelFor(MLIRContext * context,size_t begin,size_t end,FuncT && func)141 void parallelFor(MLIRContext *context, size_t begin, size_t end, FuncT &&func) {
142   parallelForEach(context, llvm::seq(begin, end), std::forward<FuncT>(func));
143 }
144 
145 } // namespace mlir
146 
147 #endif // MLIR_IR_THREADING_H
148