xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp (revision cca32174fef004aadc177fcde44904e326c639fb)
1a70aa7bbSRiver Riddle //===- ParallelLoopCollapsing.cpp - Pass collapsing parallel loop indices -===//
2a70aa7bbSRiver Riddle //
3a70aa7bbSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a70aa7bbSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5a70aa7bbSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a70aa7bbSRiver Riddle //
7a70aa7bbSRiver Riddle //===----------------------------------------------------------------------===//
8a70aa7bbSRiver Riddle 
9039b969bSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h"
1067d0d7acSMichele Scuttari 
11*cca32174SMaheshRavishankar #include "mlir/Dialect/Affine/IR/AffineOps.h"
1267d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/IR/SCF.h"
13f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/Utils.h"
14a70aa7bbSRiver Riddle #include "mlir/Transforms/RegionUtils.h"
15981932bcSTres Popp #include "llvm/ADT/SmallSet.h"
16a70aa7bbSRiver Riddle #include "llvm/Support/CommandLine.h"
17a70aa7bbSRiver Riddle #include "llvm/Support/Debug.h"
18a70aa7bbSRiver Riddle 
1967d0d7acSMichele Scuttari namespace mlir {
20981932bcSTres Popp #define GEN_PASS_DEF_TESTSCFPARALLELLOOPCOLLAPSING
2167d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
2267d0d7acSMichele Scuttari } // namespace mlir
2367d0d7acSMichele Scuttari 
24a70aa7bbSRiver Riddle #define DEBUG_TYPE "parallel-loop-collapsing"
25a70aa7bbSRiver Riddle 
26a70aa7bbSRiver Riddle using namespace mlir;
27a70aa7bbSRiver Riddle 
28a70aa7bbSRiver Riddle namespace {
29981932bcSTres Popp struct TestSCFParallelLoopCollapsing
30981932bcSTres Popp     : public impl::TestSCFParallelLoopCollapsingBase<
31981932bcSTres Popp           TestSCFParallelLoopCollapsing> {
325aeb604cSMaheshRavishankar 
33a70aa7bbSRiver Riddle   void runOnOperation() override {
34a70aa7bbSRiver Riddle     Operation *module = getOperation();
35a70aa7bbSRiver Riddle 
36a70aa7bbSRiver Riddle     // The common case for GPU dialect will be simplifying the ParallelOp to 3
37a70aa7bbSRiver Riddle     // arguments, so we do that here to simplify things.
38a70aa7bbSRiver Riddle     llvm::SmallVector<std::vector<unsigned>, 3> combinedLoops;
39981932bcSTres Popp 
40981932bcSTres Popp     // Gather the input args into the format required by
41981932bcSTres Popp     // `collapseParallelLoops`.
42a70aa7bbSRiver Riddle     if (!clCollapsedIndices0.empty())
43a70aa7bbSRiver Riddle       combinedLoops.push_back(clCollapsedIndices0);
44981932bcSTres Popp     if (!clCollapsedIndices1.empty()) {
45981932bcSTres Popp       if (clCollapsedIndices0.empty()) {
46981932bcSTres Popp         llvm::errs()
47981932bcSTres Popp             << "collapsed-indices-1 specified but not collapsed-indices-0";
48981932bcSTres Popp         signalPassFailure();
49981932bcSTres Popp         return;
50981932bcSTres Popp       }
51a70aa7bbSRiver Riddle       combinedLoops.push_back(clCollapsedIndices1);
52981932bcSTres Popp     }
53981932bcSTres Popp     if (!clCollapsedIndices2.empty()) {
54981932bcSTres Popp       if (clCollapsedIndices1.empty()) {
55981932bcSTres Popp         llvm::errs()
56981932bcSTres Popp             << "collapsed-indices-2 specified but not collapsed-indices-1";
57981932bcSTres Popp         signalPassFailure();
58981932bcSTres Popp         return;
59981932bcSTres Popp       }
60a70aa7bbSRiver Riddle       combinedLoops.push_back(clCollapsedIndices2);
61981932bcSTres Popp     }
62981932bcSTres Popp 
63981932bcSTres Popp     if (combinedLoops.empty()) {
64981932bcSTres Popp       llvm::errs() << "No collapsed-indices were specified. This pass is only "
65981932bcSTres Popp                       "for testing and does not automatically collapse all "
66981932bcSTres Popp                       "parallel loops or similar.";
67981932bcSTres Popp       signalPassFailure();
68981932bcSTres Popp       return;
69981932bcSTres Popp     }
70981932bcSTres Popp 
71981932bcSTres Popp     // Confirm that the specified loops are [0,N) by testing that N values exist
72981932bcSTres Popp     // with the maximum value being N-1.
73981932bcSTres Popp     llvm::SmallSet<unsigned, 8> flattenedCombinedLoops;
74981932bcSTres Popp     unsigned maxCollapsedIndex = 0;
75981932bcSTres Popp     for (auto &loops : combinedLoops) {
76981932bcSTres Popp       for (auto &loop : loops) {
77981932bcSTres Popp         flattenedCombinedLoops.insert(loop);
78981932bcSTres Popp         maxCollapsedIndex = std::max(maxCollapsedIndex, loop);
79981932bcSTres Popp       }
80981932bcSTres Popp     }
81981932bcSTres Popp 
82981932bcSTres Popp     if (maxCollapsedIndex != flattenedCombinedLoops.size() - 1 ||
83981932bcSTres Popp         !flattenedCombinedLoops.contains(maxCollapsedIndex)) {
84981932bcSTres Popp       llvm::errs()
85981932bcSTres Popp           << "collapsed-indices arguments must include all values [0,N).";
86981932bcSTres Popp       signalPassFailure();
87981932bcSTres Popp       return;
88981932bcSTres Popp     }
89981932bcSTres Popp 
90981932bcSTres Popp     // Only apply the transformation on parallel loops where the specified
91981932bcSTres Popp     // transformation is valid, but do NOT early abort in the case of invalid
92981932bcSTres Popp     // loops.
935aeb604cSMaheshRavishankar     IRRewriter rewriter(&getContext());
94981932bcSTres Popp     module->walk([&](scf::ParallelOp op) {
95981932bcSTres Popp       if (flattenedCombinedLoops.size() != op.getNumLoops()) {
96981932bcSTres Popp         op.emitOpError("has ")
97981932bcSTres Popp             << op.getNumLoops()
98981932bcSTres Popp             << " iter args while this limited functionality testing pass was "
99981932bcSTres Popp                "configured only for loops with exactly "
100981932bcSTres Popp             << flattenedCombinedLoops.size() << " iter args.";
101981932bcSTres Popp         return;
102981932bcSTres Popp       }
1035aeb604cSMaheshRavishankar       collapseParallelLoops(rewriter, op, combinedLoops);
104a70aa7bbSRiver Riddle     });
105a70aa7bbSRiver Riddle   }
106a70aa7bbSRiver Riddle };
107a70aa7bbSRiver Riddle } // namespace
108039b969bSMichele Scuttari 
109981932bcSTres Popp std::unique_ptr<Pass> mlir::createTestSCFParallelLoopCollapsingPass() {
110981932bcSTres Popp   return std::make_unique<TestSCFParallelLoopCollapsing>();
111039b969bSMichele Scuttari }
112