xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp (revision 2be8af8f0e0780901213b6fd3013a5268ddc3359)
1 //===- ParallelLoopCollapsing.cpp - Pass collapsing parallel loop indices -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/SCF/IR/SCF.h"
12 #include "mlir/Dialect/SCF/Utils/Utils.h"
13 #include "mlir/Transforms/RegionUtils.h"
14 #include "llvm/Support/CommandLine.h"
15 #include "llvm/Support/Debug.h"
16 
17 namespace mlir {
18 #define GEN_PASS_DEF_SCFPARALLELLOOPCOLLAPSINGPASS
19 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
20 } // namespace mlir
21 
22 #define DEBUG_TYPE "parallel-loop-collapsing"
23 
24 using namespace mlir;
25 
26 namespace {
27 struct SCFParallelLoopCollapsingPass
28     : public impl::SCFParallelLoopCollapsingPassBase<
29           SCFParallelLoopCollapsingPass> {
30   using SCFParallelLoopCollapsingPassBase::SCFParallelLoopCollapsingPassBase;
31 
32   void runOnOperation() override {
33     Operation *module = getOperation();
34 
35     module->walk([&](scf::ParallelOp op) {
36       // The common case for GPU dialect will be simplifying the ParallelOp to 3
37       // arguments, so we do that here to simplify things.
38       llvm::SmallVector<std::vector<unsigned>, 3> combinedLoops;
39       if (!clCollapsedIndices0.empty())
40         combinedLoops.push_back(clCollapsedIndices0);
41       if (!clCollapsedIndices1.empty())
42         combinedLoops.push_back(clCollapsedIndices1);
43       if (!clCollapsedIndices2.empty())
44         combinedLoops.push_back(clCollapsedIndices2);
45       collapseParallelLoops(op, combinedLoops);
46     });
47   }
48 };
49 } // namespace
50