xref: /llvm-project/mlir/test/lib/Dialect/SCF/TestLoopParametricTiling.cpp (revision caf8942cd9e52fca35992ab34af0a1cec1866759)
1a70aa7bbSRiver Riddle //===- TestLoopParametricTiling.cpp --- Parametric loop tiling pass -------===//
2a70aa7bbSRiver Riddle //
3a70aa7bbSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a70aa7bbSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5a70aa7bbSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a70aa7bbSRiver Riddle //
7a70aa7bbSRiver Riddle //===----------------------------------------------------------------------===//
8a70aa7bbSRiver Riddle //
// This file implements a test pass that parametrically tiles nests of scf.for loops.
10a70aa7bbSRiver Riddle //
11a70aa7bbSRiver Riddle //===----------------------------------------------------------------------===//
12a70aa7bbSRiver Riddle 
138b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
14f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/Utils.h"
15a70aa7bbSRiver Riddle #include "mlir/IR/Builders.h"
16a70aa7bbSRiver Riddle #include "mlir/Pass/Pass.h"
17a70aa7bbSRiver Riddle 
18a70aa7bbSRiver Riddle using namespace mlir;
19a70aa7bbSRiver Riddle 
20a70aa7bbSRiver Riddle namespace {
21a70aa7bbSRiver Riddle 
22a70aa7bbSRiver Riddle // Extracts fixed-range loops for top-level loop nests with ranges defined in
23a70aa7bbSRiver Riddle // the pass constructor.  Assumes loops are permutable.
24a70aa7bbSRiver Riddle class SimpleParametricLoopTilingPass
2587d6bf37SRiver Riddle     : public PassWrapper<SimpleParametricLoopTilingPass, OperationPass<>> {
26a70aa7bbSRiver Riddle public:
275e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SimpleParametricLoopTilingPass)
285e50dd04SRiver Riddle 
29a70aa7bbSRiver Riddle   StringRef getArgument() const final {
30a70aa7bbSRiver Riddle     return "test-extract-fixed-outer-loops";
31a70aa7bbSRiver Riddle   }
32a70aa7bbSRiver Riddle   StringRef getDescription() const final {
33a70aa7bbSRiver Riddle     return "test application of parametric tiling to the outer loops so that "
34a70aa7bbSRiver Riddle            "the ranges of outer loops become static";
35a70aa7bbSRiver Riddle   }
36a70aa7bbSRiver Riddle   SimpleParametricLoopTilingPass() = default;
37a70aa7bbSRiver Riddle   SimpleParametricLoopTilingPass(const SimpleParametricLoopTilingPass &) {}
38a70aa7bbSRiver Riddle   explicit SimpleParametricLoopTilingPass(ArrayRef<int64_t> outerLoopSizes) {
39a70aa7bbSRiver Riddle     sizes = outerLoopSizes;
40a70aa7bbSRiver Riddle   }
41a70aa7bbSRiver Riddle 
42a70aa7bbSRiver Riddle   void runOnOperation() override {
43*caf8942cSKai Sasaki     if (sizes.empty()) {
44*caf8942cSKai Sasaki       emitError(
45*caf8942cSKai Sasaki           UnknownLoc::get(&getContext()),
46*caf8942cSKai Sasaki           "missing `test-outer-loop-sizes` pass-option for outer loop sizes");
47*caf8942cSKai Sasaki       signalPassFailure();
48*caf8942cSKai Sasaki       return;
49*caf8942cSKai Sasaki     }
5087d6bf37SRiver Riddle     getOperation()->walk([this](scf::ForOp op) {
51a70aa7bbSRiver Riddle       // Ignore nested loops.
52a70aa7bbSRiver Riddle       if (op->getParentRegion()->getParentOfType<scf::ForOp>())
53a70aa7bbSRiver Riddle         return;
54a70aa7bbSRiver Riddle       extractFixedOuterLoops(op, sizes);
55a70aa7bbSRiver Riddle     });
56a70aa7bbSRiver Riddle   }
57a70aa7bbSRiver Riddle 
58a70aa7bbSRiver Riddle   ListOption<int64_t> sizes{
596edef135SRiver Riddle       *this, "test-outer-loop-sizes",
60a70aa7bbSRiver Riddle       llvm::cl::desc(
61a70aa7bbSRiver Riddle           "fixed number of iterations that the outer loops should have")};
62a70aa7bbSRiver Riddle };
63a70aa7bbSRiver Riddle } // namespace
64a70aa7bbSRiver Riddle 
65a70aa7bbSRiver Riddle namespace mlir {
66a70aa7bbSRiver Riddle namespace test {
67a70aa7bbSRiver Riddle void registerSimpleParametricTilingPass() {
68a70aa7bbSRiver Riddle   PassRegistration<SimpleParametricLoopTilingPass>();
69a70aa7bbSRiver Riddle }
70a70aa7bbSRiver Riddle } // namespace test
71a70aa7bbSRiver Riddle } // namespace mlir
72