xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp (revision cca32174fef004aadc177fcde44904e326c639fb)
1 //===- ParallelLoopCollapsing.cpp - Pass collapsing parallel loop indices -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/SCF/IR/SCF.h"
13 #include "mlir/Dialect/SCF/Utils/Utils.h"
14 #include "mlir/Transforms/RegionUtils.h"
15 #include "llvm/ADT/SmallSet.h"
16 #include "llvm/Support/CommandLine.h"
17 #include "llvm/Support/Debug.h"
18 
19 namespace mlir {
20 #define GEN_PASS_DEF_TESTSCFPARALLELLOOPCOLLAPSING
21 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
22 } // namespace mlir
23 
24 #define DEBUG_TYPE "parallel-loop-collapsing"
25 
26 using namespace mlir;
27 
28 namespace {
29 struct TestSCFParallelLoopCollapsing
30     : public impl::TestSCFParallelLoopCollapsingBase<
31           TestSCFParallelLoopCollapsing> {
32 
33   void runOnOperation() override {
34     Operation *module = getOperation();
35 
36     // The common case for GPU dialect will be simplifying the ParallelOp to 3
37     // arguments, so we do that here to simplify things.
38     llvm::SmallVector<std::vector<unsigned>, 3> combinedLoops;
39 
40     // Gather the input args into the format required by
41     // `collapseParallelLoops`.
42     if (!clCollapsedIndices0.empty())
43       combinedLoops.push_back(clCollapsedIndices0);
44     if (!clCollapsedIndices1.empty()) {
45       if (clCollapsedIndices0.empty()) {
46         llvm::errs()
47             << "collapsed-indices-1 specified but not collapsed-indices-0";
48         signalPassFailure();
49         return;
50       }
51       combinedLoops.push_back(clCollapsedIndices1);
52     }
53     if (!clCollapsedIndices2.empty()) {
54       if (clCollapsedIndices1.empty()) {
55         llvm::errs()
56             << "collapsed-indices-2 specified but not collapsed-indices-1";
57         signalPassFailure();
58         return;
59       }
60       combinedLoops.push_back(clCollapsedIndices2);
61     }
62 
63     if (combinedLoops.empty()) {
64       llvm::errs() << "No collapsed-indices were specified. This pass is only "
65                       "for testing and does not automatically collapse all "
66                       "parallel loops or similar.";
67       signalPassFailure();
68       return;
69     }
70 
71     // Confirm that the specified loops are [0,N) by testing that N values exist
72     // with the maximum value being N-1.
73     llvm::SmallSet<unsigned, 8> flattenedCombinedLoops;
74     unsigned maxCollapsedIndex = 0;
75     for (auto &loops : combinedLoops) {
76       for (auto &loop : loops) {
77         flattenedCombinedLoops.insert(loop);
78         maxCollapsedIndex = std::max(maxCollapsedIndex, loop);
79       }
80     }
81 
82     if (maxCollapsedIndex != flattenedCombinedLoops.size() - 1 ||
83         !flattenedCombinedLoops.contains(maxCollapsedIndex)) {
84       llvm::errs()
85           << "collapsed-indices arguments must include all values [0,N).";
86       signalPassFailure();
87       return;
88     }
89 
90     // Only apply the transformation on parallel loops where the specified
91     // transformation is valid, but do NOT early abort in the case of invalid
92     // loops.
93     IRRewriter rewriter(&getContext());
94     module->walk([&](scf::ParallelOp op) {
95       if (flattenedCombinedLoops.size() != op.getNumLoops()) {
96         op.emitOpError("has ")
97             << op.getNumLoops()
98             << " iter args while this limited functionality testing pass was "
99                "configured only for loops with exactly "
100             << flattenedCombinedLoops.size() << " iter args.";
101         return;
102       }
103       collapseParallelLoops(rewriter, op, combinedLoops);
104     });
105   }
106 };
107 } // namespace
108 
109 std::unique_ptr<Pass> mlir::createTestSCFParallelLoopCollapsingPass() {
110   return std::make_unique<TestSCFParallelLoopCollapsing>();
111 }
112