1a70aa7bbSRiver Riddle //===- ParallelLoopCollapsing.cpp - Pass collapsing parallel loop indices -===// 2a70aa7bbSRiver Riddle // 3a70aa7bbSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4a70aa7bbSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 5a70aa7bbSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6a70aa7bbSRiver Riddle // 7a70aa7bbSRiver Riddle //===----------------------------------------------------------------------===// 8a70aa7bbSRiver Riddle 9039b969bSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h" 1067d0d7acSMichele Scuttari 11*cca32174SMaheshRavishankar #include "mlir/Dialect/Affine/IR/AffineOps.h" 1267d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/IR/SCF.h" 13f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/Utils.h" 14a70aa7bbSRiver Riddle #include "mlir/Transforms/RegionUtils.h" 15981932bcSTres Popp #include "llvm/ADT/SmallSet.h" 16a70aa7bbSRiver Riddle #include "llvm/Support/CommandLine.h" 17a70aa7bbSRiver Riddle #include "llvm/Support/Debug.h" 18a70aa7bbSRiver Riddle 1967d0d7acSMichele Scuttari namespace mlir { 20981932bcSTres Popp #define GEN_PASS_DEF_TESTSCFPARALLELLOOPCOLLAPSING 2167d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" 2267d0d7acSMichele Scuttari } // namespace mlir 2367d0d7acSMichele Scuttari 24a70aa7bbSRiver Riddle #define DEBUG_TYPE "parallel-loop-collapsing" 25a70aa7bbSRiver Riddle 26a70aa7bbSRiver Riddle using namespace mlir; 27a70aa7bbSRiver Riddle 28a70aa7bbSRiver Riddle namespace { 29981932bcSTres Popp struct TestSCFParallelLoopCollapsing 30981932bcSTres Popp : public impl::TestSCFParallelLoopCollapsingBase< 31981932bcSTres Popp TestSCFParallelLoopCollapsing> { 325aeb604cSMaheshRavishankar 33a70aa7bbSRiver Riddle void runOnOperation() override { 34a70aa7bbSRiver Riddle Operation *module = getOperation(); 35a70aa7bbSRiver Riddle 36a70aa7bbSRiver Riddle // The common case for GPU dialect will be simplifying the ParallelOp to 3 37a70aa7bbSRiver Riddle // arguments, so we do that here to simplify things. 38a70aa7bbSRiver Riddle llvm::SmallVector<std::vector<unsigned>, 3> combinedLoops; 39981932bcSTres Popp 40981932bcSTres Popp // Gather the input args into the format required by 41981932bcSTres Popp // `collapseParallelLoops`. 42a70aa7bbSRiver Riddle if (!clCollapsedIndices0.empty()) 43a70aa7bbSRiver Riddle combinedLoops.push_back(clCollapsedIndices0); 44981932bcSTres Popp if (!clCollapsedIndices1.empty()) { 45981932bcSTres Popp if (clCollapsedIndices0.empty()) { 46981932bcSTres Popp llvm::errs() 47981932bcSTres Popp << "collapsed-indices-1 specified but not collapsed-indices-0"; 48981932bcSTres Popp signalPassFailure(); 49981932bcSTres Popp return; 50981932bcSTres Popp } 51a70aa7bbSRiver Riddle combinedLoops.push_back(clCollapsedIndices1); 52981932bcSTres Popp } 53981932bcSTres Popp if (!clCollapsedIndices2.empty()) { 54981932bcSTres Popp if (clCollapsedIndices1.empty()) { 55981932bcSTres Popp llvm::errs() 56981932bcSTres Popp << "collapsed-indices-2 specified but not collapsed-indices-1"; 57981932bcSTres Popp signalPassFailure(); 58981932bcSTres Popp return; 59981932bcSTres Popp } 60a70aa7bbSRiver Riddle combinedLoops.push_back(clCollapsedIndices2); 61981932bcSTres Popp } 62981932bcSTres Popp 63981932bcSTres Popp if (combinedLoops.empty()) { 64981932bcSTres Popp llvm::errs() << "No collapsed-indices were specified. This pass is only " 65981932bcSTres Popp "for testing and does not automatically collapse all " 66981932bcSTres Popp "parallel loops or similar."; 67981932bcSTres Popp signalPassFailure(); 68981932bcSTres Popp return; 69981932bcSTres Popp } 70981932bcSTres Popp 71981932bcSTres Popp // Confirm that the specified loops are [0,N) by testing that N values exist 72981932bcSTres Popp // with the maximum value being N-1. 73981932bcSTres Popp llvm::SmallSet<unsigned, 8> flattenedCombinedLoops; 74981932bcSTres Popp unsigned maxCollapsedIndex = 0; 75981932bcSTres Popp for (auto &loops : combinedLoops) { 76981932bcSTres Popp for (auto &loop : loops) { 77981932bcSTres Popp flattenedCombinedLoops.insert(loop); 78981932bcSTres Popp maxCollapsedIndex = std::max(maxCollapsedIndex, loop); 79981932bcSTres Popp } 80981932bcSTres Popp } 81981932bcSTres Popp 82981932bcSTres Popp if (maxCollapsedIndex != flattenedCombinedLoops.size() - 1 || 83981932bcSTres Popp !flattenedCombinedLoops.contains(maxCollapsedIndex)) { 84981932bcSTres Popp llvm::errs() 85981932bcSTres Popp << "collapsed-indices arguments must include all values [0,N)."; 86981932bcSTres Popp signalPassFailure(); 87981932bcSTres Popp return; 88981932bcSTres Popp } 89981932bcSTres Popp 90981932bcSTres Popp // Only apply the transformation on parallel loops where the specified 91981932bcSTres Popp // transformation is valid, but do NOT early abort in the case of invalid 92981932bcSTres Popp // loops. 935aeb604cSMaheshRavishankar IRRewriter rewriter(&getContext()); 94981932bcSTres Popp module->walk([&](scf::ParallelOp op) { 95981932bcSTres Popp if (flattenedCombinedLoops.size() != op.getNumLoops()) { 96981932bcSTres Popp op.emitOpError("has ") 97981932bcSTres Popp << op.getNumLoops() 98981932bcSTres Popp << " iter args while this limited functionality testing pass was " 99981932bcSTres Popp "configured only for loops with exactly " 100981932bcSTres Popp << flattenedCombinedLoops.size() << " iter args."; 101981932bcSTres Popp return; 102981932bcSTres Popp } 1035aeb604cSMaheshRavishankar collapseParallelLoops(rewriter, op, combinedLoops); 104a70aa7bbSRiver Riddle }); 105a70aa7bbSRiver Riddle } 106a70aa7bbSRiver Riddle }; 107a70aa7bbSRiver Riddle } // namespace 108039b969bSMichele Scuttari 109981932bcSTres Popp std::unique_ptr<Pass> mlir::createTestSCFParallelLoopCollapsingPass() { 110981932bcSTres Popp return std::make_unique<TestSCFParallelLoopCollapsing>(); 111039b969bSMichele Scuttari } 112