1a70aa7bbSRiver Riddle //===-------- TestLoopUnrolling.cpp --- loop unrolling test pass ----------===//
2a70aa7bbSRiver Riddle //
3a70aa7bbSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a70aa7bbSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5a70aa7bbSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a70aa7bbSRiver Riddle //
7a70aa7bbSRiver Riddle //===----------------------------------------------------------------------===//
8a70aa7bbSRiver Riddle //
9a70aa7bbSRiver Riddle // This file implements a pass to unroll loops by a specified unroll factor.
10a70aa7bbSRiver Riddle //
11a70aa7bbSRiver Riddle //===----------------------------------------------------------------------===//
12a70aa7bbSRiver Riddle
13*abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
148b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
15f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/Utils.h"
16a70aa7bbSRiver Riddle #include "mlir/IR/Builders.h"
17a70aa7bbSRiver Riddle #include "mlir/Pass/Pass.h"
18a70aa7bbSRiver Riddle
19a70aa7bbSRiver Riddle using namespace mlir;
20a70aa7bbSRiver Riddle
21a70aa7bbSRiver Riddle namespace {
22a70aa7bbSRiver Riddle
getNestingDepth(Operation * op)23a70aa7bbSRiver Riddle static unsigned getNestingDepth(Operation *op) {
24a70aa7bbSRiver Riddle Operation *currOp = op;
25a70aa7bbSRiver Riddle unsigned depth = 0;
26a70aa7bbSRiver Riddle while ((currOp = currOp->getParentOp())) {
27a70aa7bbSRiver Riddle if (isa<scf::ForOp>(currOp))
28a70aa7bbSRiver Riddle depth++;
29a70aa7bbSRiver Riddle }
30a70aa7bbSRiver Riddle return depth;
31a70aa7bbSRiver Riddle }
32a70aa7bbSRiver Riddle
335e50dd04SRiver Riddle struct TestLoopUnrollingPass
3487d6bf37SRiver Riddle : public PassWrapper<TestLoopUnrollingPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon11ffe58f0111::TestLoopUnrollingPass355e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopUnrollingPass)
365e50dd04SRiver Riddle
37a70aa7bbSRiver Riddle StringRef getArgument() const final { return "test-loop-unrolling"; }
getDescription__anon11ffe58f0111::TestLoopUnrollingPass38a70aa7bbSRiver Riddle StringRef getDescription() const final {
39a70aa7bbSRiver Riddle return "Tests loop unrolling transformation";
40a70aa7bbSRiver Riddle }
41a70aa7bbSRiver Riddle TestLoopUnrollingPass() = default;
TestLoopUnrollingPass__anon11ffe58f0111::TestLoopUnrollingPass42a70aa7bbSRiver Riddle TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
TestLoopUnrollingPass__anon11ffe58f0111::TestLoopUnrollingPass43a70aa7bbSRiver Riddle explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
44a70aa7bbSRiver Riddle unsigned loopDepthParam,
45a70aa7bbSRiver Riddle bool annotateLoopParam) {
46a70aa7bbSRiver Riddle unrollFactor = unrollFactorParam;
47a70aa7bbSRiver Riddle loopDepth = loopDepthParam;
48a70aa7bbSRiver Riddle annotateLoop = annotateLoopParam;
49a70aa7bbSRiver Riddle }
50a70aa7bbSRiver Riddle
getDependentDialects__anon11ffe58f0111::TestLoopUnrollingPass51a70aa7bbSRiver Riddle void getDependentDialects(DialectRegistry ®istry) const override {
52*abc362a1SJakub Kuderski registry.insert<arith::ArithDialect>();
53a70aa7bbSRiver Riddle }
54a70aa7bbSRiver Riddle
runOnOperation__anon11ffe58f0111::TestLoopUnrollingPass55a70aa7bbSRiver Riddle void runOnOperation() override {
56a70aa7bbSRiver Riddle SmallVector<scf::ForOp, 4> loops;
5787d6bf37SRiver Riddle getOperation()->walk([&](scf::ForOp forOp) {
58a70aa7bbSRiver Riddle if (getNestingDepth(forOp) == loopDepth)
59a70aa7bbSRiver Riddle loops.push_back(forOp);
60a70aa7bbSRiver Riddle });
61a70aa7bbSRiver Riddle auto annotateFn = [this](unsigned i, Operation *op, OpBuilder b) {
62a70aa7bbSRiver Riddle if (annotateLoop) {
63a70aa7bbSRiver Riddle op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i));
64a70aa7bbSRiver Riddle }
65a70aa7bbSRiver Riddle };
66a70aa7bbSRiver Riddle for (auto loop : loops)
67a70aa7bbSRiver Riddle (void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
68a70aa7bbSRiver Riddle }
69a70aa7bbSRiver Riddle Option<uint64_t> unrollFactor{*this, "unroll-factor",
70a70aa7bbSRiver Riddle llvm::cl::desc("Loop unroll factor."),
71a70aa7bbSRiver Riddle llvm::cl::init(1)};
72a70aa7bbSRiver Riddle Option<bool> annotateLoop{*this, "annotate",
73a70aa7bbSRiver Riddle llvm::cl::desc("Annotate unrolled iterations."),
74a70aa7bbSRiver Riddle llvm::cl::init(false)};
75a70aa7bbSRiver Riddle Option<bool> unrollUpToFactor{*this, "unroll-up-to-factor",
76a70aa7bbSRiver Riddle llvm::cl::desc("Loop unroll up to factor."),
77a70aa7bbSRiver Riddle llvm::cl::init(false)};
78a70aa7bbSRiver Riddle Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
79a70aa7bbSRiver Riddle llvm::cl::init(0)};
80a70aa7bbSRiver Riddle };
81a70aa7bbSRiver Riddle } // namespace
82a70aa7bbSRiver Riddle
83a70aa7bbSRiver Riddle namespace mlir {
84a70aa7bbSRiver Riddle namespace test {
registerTestLoopUnrollingPass()85a70aa7bbSRiver Riddle void registerTestLoopUnrollingPass() {
86a70aa7bbSRiver Riddle PassRegistration<TestLoopUnrollingPass>();
87a70aa7bbSRiver Riddle }
88a70aa7bbSRiver Riddle } // namespace test
89a70aa7bbSRiver Riddle } // namespace mlir
90