xref: /llvm-project/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1e8f07cdbSVictor Perez //===- TestSCFWrapInZeroTripCheck.cpp -- Pass to test SCF zero-trip-check -===//
2f7201505SJerry Wu //
3f7201505SJerry Wu // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f7201505SJerry Wu // See https://llvm.org/LICENSE.txt for license information.
5f7201505SJerry Wu // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f7201505SJerry Wu //
7f7201505SJerry Wu //===----------------------------------------------------------------------===//
8f7201505SJerry Wu //
9f7201505SJerry Wu // This file implements the passes to test wrap-in-zero-trip-check transforms on
10f7201505SJerry Wu // SCF loop ops.
11f7201505SJerry Wu //
12f7201505SJerry Wu //===----------------------------------------------------------------------===//
13f7201505SJerry Wu 
14f7201505SJerry Wu #include "mlir/Dialect/Func/IR/FuncOps.h"
15f7201505SJerry Wu #include "mlir/Dialect/SCF/IR/SCF.h"
16e8f07cdbSVictor Perez #include "mlir/Dialect/SCF/Transforms/Patterns.h"
17f7201505SJerry Wu #include "mlir/Dialect/SCF/Transforms/Transforms.h"
18f7201505SJerry Wu #include "mlir/IR/PatternMatch.h"
19f7201505SJerry Wu #include "mlir/Pass/Pass.h"
20e8f07cdbSVictor Perez #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21f7201505SJerry Wu 
22f7201505SJerry Wu using namespace mlir;
23f7201505SJerry Wu 
24f7201505SJerry Wu namespace {
25f7201505SJerry Wu 
26f7201505SJerry Wu struct TestWrapWhileLoopInZeroTripCheckPass
27f7201505SJerry Wu     : public PassWrapper<TestWrapWhileLoopInZeroTripCheckPass,
28f7201505SJerry Wu                          OperationPass<func::FuncOp>> {
29f7201505SJerry Wu   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
30f7201505SJerry Wu       TestWrapWhileLoopInZeroTripCheckPass)
31f7201505SJerry Wu 
32f7201505SJerry Wu   StringRef getArgument() const final {
33f7201505SJerry Wu     return "test-wrap-scf-while-loop-in-zero-trip-check";
34f7201505SJerry Wu   }
35f7201505SJerry Wu 
36f7201505SJerry Wu   StringRef getDescription() const final {
37f7201505SJerry Wu     return "test scf::wrapWhileLoopInZeroTripCheck";
38f7201505SJerry Wu   }
39f7201505SJerry Wu 
40f7201505SJerry Wu   TestWrapWhileLoopInZeroTripCheckPass() = default;
41f7201505SJerry Wu   TestWrapWhileLoopInZeroTripCheckPass(
42f7201505SJerry Wu       const TestWrapWhileLoopInZeroTripCheckPass &) {}
43f7201505SJerry Wu   explicit TestWrapWhileLoopInZeroTripCheckPass(bool forceCreateCheckParam) {
44f7201505SJerry Wu     forceCreateCheck = forceCreateCheckParam;
45f7201505SJerry Wu   }
46f7201505SJerry Wu 
47f7201505SJerry Wu   void runOnOperation() override {
48f7201505SJerry Wu     func::FuncOp func = getOperation();
49f7201505SJerry Wu     MLIRContext *context = &getContext();
50f7201505SJerry Wu     IRRewriter rewriter(context);
51e8f07cdbSVictor Perez     if (forceCreateCheck) {
52f7201505SJerry Wu       func.walk([&](scf::WhileOp op) {
53f7201505SJerry Wu         FailureOr<scf::WhileOp> result =
54f7201505SJerry Wu             scf::wrapWhileLoopInZeroTripCheck(op, rewriter, forceCreateCheck);
55f7201505SJerry Wu         // Ignore not implemented failure in tests. The expected output should
56f7201505SJerry Wu         // catch problems (e.g. transformation doesn't happen).
57f7201505SJerry Wu         (void)result;
58f7201505SJerry Wu       });
59e8f07cdbSVictor Perez     } else {
60e8f07cdbSVictor Perez       RewritePatternSet patterns(context);
61e8f07cdbSVictor Perez       scf::populateSCFRotateWhileLoopPatterns(patterns);
62*09dfc571SJacques Pienaar       (void)applyPatternsGreedily(func, std::move(patterns));
63e8f07cdbSVictor Perez     }
64f7201505SJerry Wu   }
65f7201505SJerry Wu 
66f7201505SJerry Wu   Option<bool> forceCreateCheck{
67f7201505SJerry Wu       *this, "force-create-check",
68f7201505SJerry Wu       llvm::cl::desc("Force to create zero-trip-check."),
69f7201505SJerry Wu       llvm::cl::init(false)};
70f7201505SJerry Wu };
71f7201505SJerry Wu 
72f7201505SJerry Wu } // namespace
73f7201505SJerry Wu 
74f7201505SJerry Wu namespace mlir {
75f7201505SJerry Wu namespace test {
76f7201505SJerry Wu void registerTestSCFWrapInZeroTripCheckPasses() {
77f7201505SJerry Wu   PassRegistration<TestWrapWhileLoopInZeroTripCheckPass>();
78f7201505SJerry Wu }
79f7201505SJerry Wu } // namespace test
80f7201505SJerry Wu } // namespace mlir
81