xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp (revision 5aeb604c7ce417eea110f9803a6c5cb1cdbc5372)
1 //===- ParallelLoopCollapsing.cpp - Pass collapsing parallel loop indices -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/SCF/IR/SCF.h"
12 #include "mlir/Dialect/SCF/Utils/Utils.h"
13 #include "mlir/Transforms/RegionUtils.h"
14 #include "llvm/ADT/SmallSet.h"
15 #include "llvm/Support/CommandLine.h"
16 #include "llvm/Support/Debug.h"
17 
18 namespace mlir {
19 #define GEN_PASS_DEF_TESTSCFPARALLELLOOPCOLLAPSING
20 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
21 } // namespace mlir
22 
23 #define DEBUG_TYPE "parallel-loop-collapsing"
24 
25 using namespace mlir;
26 
27 namespace {
28 struct TestSCFParallelLoopCollapsing
29     : public impl::TestSCFParallelLoopCollapsingBase<
30           TestSCFParallelLoopCollapsing> {
31 
32   void runOnOperation() override {
33     Operation *module = getOperation();
34 
35     // The common case for GPU dialect will be simplifying the ParallelOp to 3
36     // arguments, so we do that here to simplify things.
37     llvm::SmallVector<std::vector<unsigned>, 3> combinedLoops;
38 
39     // Gather the input args into the format required by
40     // `collapseParallelLoops`.
41     if (!clCollapsedIndices0.empty())
42       combinedLoops.push_back(clCollapsedIndices0);
43     if (!clCollapsedIndices1.empty()) {
44       if (clCollapsedIndices0.empty()) {
45         llvm::errs()
46             << "collapsed-indices-1 specified but not collapsed-indices-0";
47         signalPassFailure();
48         return;
49       }
50       combinedLoops.push_back(clCollapsedIndices1);
51     }
52     if (!clCollapsedIndices2.empty()) {
53       if (clCollapsedIndices1.empty()) {
54         llvm::errs()
55             << "collapsed-indices-2 specified but not collapsed-indices-1";
56         signalPassFailure();
57         return;
58       }
59       combinedLoops.push_back(clCollapsedIndices2);
60     }
61 
62     if (combinedLoops.empty()) {
63       llvm::errs() << "No collapsed-indices were specified. This pass is only "
64                       "for testing and does not automatically collapse all "
65                       "parallel loops or similar.";
66       signalPassFailure();
67       return;
68     }
69 
70     // Confirm that the specified loops are [0,N) by testing that N values exist
71     // with the maximum value being N-1.
72     llvm::SmallSet<unsigned, 8> flattenedCombinedLoops;
73     unsigned maxCollapsedIndex = 0;
74     for (auto &loops : combinedLoops) {
75       for (auto &loop : loops) {
76         flattenedCombinedLoops.insert(loop);
77         maxCollapsedIndex = std::max(maxCollapsedIndex, loop);
78       }
79     }
80 
81     if (maxCollapsedIndex != flattenedCombinedLoops.size() - 1 ||
82         !flattenedCombinedLoops.contains(maxCollapsedIndex)) {
83       llvm::errs()
84           << "collapsed-indices arguments must include all values [0,N).";
85       signalPassFailure();
86       return;
87     }
88 
89     // Only apply the transformation on parallel loops where the specified
90     // transformation is valid, but do NOT early abort in the case of invalid
91     // loops.
92     IRRewriter rewriter(&getContext());
93     module->walk([&](scf::ParallelOp op) {
94       if (flattenedCombinedLoops.size() != op.getNumLoops()) {
95         op.emitOpError("has ")
96             << op.getNumLoops()
97             << " iter args while this limited functionality testing pass was "
98                "configured only for loops with exactly "
99             << flattenedCombinedLoops.size() << " iter args.";
100         return;
101       }
102       collapseParallelLoops(rewriter, op, combinedLoops);
103     });
104   }
105 };
106 } // namespace
107 
108 std::unique_ptr<Pass> mlir::createTestSCFParallelLoopCollapsingPass() {
109   return std::make_unique<TestSCFParallelLoopCollapsing>();
110 }
111