xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp (revision 039b969b32b64b64123dce30dd28ec4e343d893f)
1 //===- ParallelLoopCollapsing.cpp - Pass collapsing parallel loop indices -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/SCF/IR/SCF.h"
11 #include "mlir/Dialect/SCF/Transforms/Passes.h"
12 #include "mlir/Dialect/SCF/Utils/Utils.h"
13 #include "mlir/Transforms/RegionUtils.h"
14 #include "llvm/Support/CommandLine.h"
15 #include "llvm/Support/Debug.h"
16 
17 #define DEBUG_TYPE "parallel-loop-collapsing"
18 
19 using namespace mlir;
20 
21 namespace {
22 struct ParallelLoopCollapsing
23     : public SCFParallelLoopCollapsingBase<ParallelLoopCollapsing> {
24   void runOnOperation() override {
25     Operation *module = getOperation();
26 
27     module->walk([&](scf::ParallelOp op) {
28       // The common case for GPU dialect will be simplifying the ParallelOp to 3
29       // arguments, so we do that here to simplify things.
30       llvm::SmallVector<std::vector<unsigned>, 3> combinedLoops;
31       if (!clCollapsedIndices0.empty())
32         combinedLoops.push_back(clCollapsedIndices0);
33       if (!clCollapsedIndices1.empty())
34         combinedLoops.push_back(clCollapsedIndices1);
35       if (!clCollapsedIndices2.empty())
36         combinedLoops.push_back(clCollapsedIndices2);
37       collapseParallelLoops(op, combinedLoops);
38     });
39   }
40 };
41 } // namespace
42 
43 std::unique_ptr<Pass> mlir::createParallelLoopCollapsingPass() {
44   return std::make_unique<ParallelLoopCollapsing>();
45 }
46