1 //===- Threading.h - MLIR Threading Utilities -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines various utilities for multithreaded processing within MLIR.
10 // These utilities automatically handle many of the necessary threading
11 // conditions, such as properly ordering diagnostics, observing if threading is
12 // disabled, etc. These utilities should be used over other threading utilities
13 // whenever feasible.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #ifndef MLIR_IR_THREADING_H
18 #define MLIR_IR_THREADING_H
19
20 #include "mlir/IR/Diagnostics.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/Support/ThreadPool.h"
23 #include <atomic>
24
25 namespace mlir {
26
27 /// Invoke the given function on the elements between [begin, end)
28 /// asynchronously. If the given function returns a failure when processing any
29 /// of the elements, execution is stopped and a failure is returned from this
30 /// function. This means that in the case of failure, not all elements of the
31 /// range will be processed. Diagnostics emitted during processing are ordered
32 /// relative to the element's position within [begin, end). If the provided
33 /// context does not have multi-threading enabled, this function always
34 /// processes elements sequentially.
35 template <typename IteratorT, typename FuncT>
failableParallelForEach(MLIRContext * context,IteratorT begin,IteratorT end,FuncT && func)36 LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
37 IteratorT end, FuncT &&func) {
38 unsigned numElements = static_cast<unsigned>(std::distance(begin, end));
39 if (numElements == 0)
40 return success();
41
42 // If multithreading is disabled or there is a small number of elements,
43 // process the elements directly on this thread.
44 if (!context->isMultithreadingEnabled() || numElements <= 1) {
45 for (; begin != end; ++begin)
46 if (failed(func(*begin)))
47 return failure();
48 return success();
49 }
50
51 // Build a wrapper processing function that properly initializes a parallel
52 // diagnostic handler.
53 ParallelDiagnosticHandler handler(context);
54 std::atomic<unsigned> curIndex(0);
55 std::atomic<bool> processingFailed(false);
56 auto processFn = [&] {
57 while (!processingFailed) {
58 unsigned index = curIndex++;
59 if (index >= numElements)
60 break;
61 handler.setOrderIDForThread(index);
62 if (failed(func(*std::next(begin, index))))
63 processingFailed = true;
64 handler.eraseOrderIDForThread();
65 }
66 };
67
68 // Otherwise, process the elements in parallel.
69 llvm::ThreadPoolInterface &threadPool = context->getThreadPool();
70 llvm::ThreadPoolTaskGroup tasksGroup(threadPool);
71 size_t numActions = std::min(numElements, threadPool.getMaxConcurrency());
72 for (unsigned i = 0; i < numActions; ++i)
73 tasksGroup.async(processFn);
74 // If the current thread is a worker thread from the pool, then waiting for
75 // the task group allows the current thread to also participate in processing
76 // tasks from the group, which avoid any deadlock/starvation.
77 tasksGroup.wait();
78 return failure(processingFailed);
79 }
80
81 /// Invoke the given function on the elements in the provided range
82 /// asynchronously. If the given function returns a failure when processing any
83 /// of the elements, execution is stopped and a failure is returned from this
84 /// function. This means that in the case of failure, not all elements of the
85 /// range will be processed. Diagnostics emitted during processing are ordered
86 /// relative to the element's position within the range. If the provided context
87 /// does not have multi-threading enabled, this function always processes
88 /// elements sequentially.
89 template <typename RangeT, typename FuncT>
failableParallelForEach(MLIRContext * context,RangeT && range,FuncT && func)90 LogicalResult failableParallelForEach(MLIRContext *context, RangeT &&range,
91 FuncT &&func) {
92 return failableParallelForEach(context, std::begin(range), std::end(range),
93 std::forward<FuncT>(func));
94 }
95
96 /// Invoke the given function on the elements between [begin, end)
97 /// asynchronously. If the given function returns a failure when processing any
98 /// of the elements, execution is stopped and a failure is returned from this
99 /// function. This means that in the case of failure, not all elements of the
100 /// range will be processed. Diagnostics emitted during processing are ordered
101 /// relative to the element's position within [begin, end). If the provided
102 /// context does not have multi-threading enabled, this function always
103 /// processes elements sequentially.
104 template <typename FuncT>
failableParallelForEachN(MLIRContext * context,size_t begin,size_t end,FuncT && func)105 LogicalResult failableParallelForEachN(MLIRContext *context, size_t begin,
106 size_t end, FuncT &&func) {
107 return failableParallelForEach(context, llvm::seq(begin, end),
108 std::forward<FuncT>(func));
109 }
110
111 /// Invoke the given function on the elements between [begin, end)
112 /// asynchronously. Diagnostics emitted during processing are ordered relative
113 /// to the element's position within [begin, end). If the provided context does
114 /// not have multi-threading enabled, this function always processes elements
115 /// sequentially.
116 template <typename IteratorT, typename FuncT>
parallelForEach(MLIRContext * context,IteratorT begin,IteratorT end,FuncT && func)117 void parallelForEach(MLIRContext *context, IteratorT begin, IteratorT end,
118 FuncT &&func) {
119 (void)failableParallelForEach(context, begin, end, [&](auto &&value) {
120 return func(std::forward<decltype(value)>(value)), success();
121 });
122 }
123
124 /// Invoke the given function on the elements in the provided range
125 /// asynchronously. Diagnostics emitted during processing are ordered relative
126 /// to the element's position within the range. If the provided context does not
127 /// have multi-threading enabled, this function always processes elements
128 /// sequentially.
129 template <typename RangeT, typename FuncT>
parallelForEach(MLIRContext * context,RangeT && range,FuncT && func)130 void parallelForEach(MLIRContext *context, RangeT &&range, FuncT &&func) {
131 parallelForEach(context, std::begin(range), std::end(range),
132 std::forward<FuncT>(func));
133 }
134
135 /// Invoke the given function on the elements between [begin, end)
136 /// asynchronously. Diagnostics emitted during processing are ordered relative
137 /// to the element's position within [begin, end). If the provided context does
138 /// not have multi-threading enabled, this function always processes elements
139 /// sequentially.
140 template <typename FuncT>
parallelFor(MLIRContext * context,size_t begin,size_t end,FuncT && func)141 void parallelFor(MLIRContext *context, size_t begin, size_t end, FuncT &&func) {
142 parallelForEach(context, llvm::seq(begin, end), std::forward<FuncT>(func));
143 }
144
145 } // namespace mlir
146
147 #endif // MLIR_IR_THREADING_H
148