1e8f07cdbSVictor Perez //===- TestSCFWrapInZeroTripCheck.cpp -- Pass to test SCF zero-trip-check -===// 2f7201505SJerry Wu // 3f7201505SJerry Wu // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4f7201505SJerry Wu // See https://llvm.org/LICENSE.txt for license information. 5f7201505SJerry Wu // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6f7201505SJerry Wu // 7f7201505SJerry Wu //===----------------------------------------------------------------------===// 8f7201505SJerry Wu // 9f7201505SJerry Wu // This file implements the passes to test wrap-in-zero-trip-check transforms on 10f7201505SJerry Wu // SCF loop ops. 11f7201505SJerry Wu // 12f7201505SJerry Wu //===----------------------------------------------------------------------===// 13f7201505SJerry Wu 14f7201505SJerry Wu #include "mlir/Dialect/Func/IR/FuncOps.h" 15f7201505SJerry Wu #include "mlir/Dialect/SCF/IR/SCF.h" 16e8f07cdbSVictor Perez #include "mlir/Dialect/SCF/Transforms/Patterns.h" 17f7201505SJerry Wu #include "mlir/Dialect/SCF/Transforms/Transforms.h" 18f7201505SJerry Wu #include "mlir/IR/PatternMatch.h" 19f7201505SJerry Wu #include "mlir/Pass/Pass.h" 20e8f07cdbSVictor Perez #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21f7201505SJerry Wu 22f7201505SJerry Wu using namespace mlir; 23f7201505SJerry Wu 24f7201505SJerry Wu namespace { 25f7201505SJerry Wu 26f7201505SJerry Wu struct TestWrapWhileLoopInZeroTripCheckPass 27f7201505SJerry Wu : public PassWrapper<TestWrapWhileLoopInZeroTripCheckPass, 28f7201505SJerry Wu OperationPass<func::FuncOp>> { 29f7201505SJerry Wu MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 30f7201505SJerry Wu TestWrapWhileLoopInZeroTripCheckPass) 31f7201505SJerry Wu 32f7201505SJerry Wu StringRef getArgument() const final { 33f7201505SJerry Wu return "test-wrap-scf-while-loop-in-zero-trip-check"; 34f7201505SJerry Wu } 35f7201505SJerry Wu 36f7201505SJerry Wu StringRef getDescription() const final { 37f7201505SJerry Wu return "test scf::wrapWhileLoopInZeroTripCheck"; 38f7201505SJerry Wu } 39f7201505SJerry Wu 40f7201505SJerry Wu TestWrapWhileLoopInZeroTripCheckPass() = default; 41f7201505SJerry Wu TestWrapWhileLoopInZeroTripCheckPass( 42f7201505SJerry Wu const TestWrapWhileLoopInZeroTripCheckPass &) {} 43f7201505SJerry Wu explicit TestWrapWhileLoopInZeroTripCheckPass(bool forceCreateCheckParam) { 44f7201505SJerry Wu forceCreateCheck = forceCreateCheckParam; 45f7201505SJerry Wu } 46f7201505SJerry Wu 47f7201505SJerry Wu void runOnOperation() override { 48f7201505SJerry Wu func::FuncOp func = getOperation(); 49f7201505SJerry Wu MLIRContext *context = &getContext(); 50f7201505SJerry Wu IRRewriter rewriter(context); 51e8f07cdbSVictor Perez if (forceCreateCheck) { 52f7201505SJerry Wu func.walk([&](scf::WhileOp op) { 53f7201505SJerry Wu FailureOr<scf::WhileOp> result = 54f7201505SJerry Wu scf::wrapWhileLoopInZeroTripCheck(op, rewriter, forceCreateCheck); 55f7201505SJerry Wu // Ignore not implemented failure in tests. The expected output should 56f7201505SJerry Wu // catch problems (e.g. transformation doesn't happen). 57f7201505SJerry Wu (void)result; 58f7201505SJerry Wu }); 59e8f07cdbSVictor Perez } else { 60e8f07cdbSVictor Perez RewritePatternSet patterns(context); 61e8f07cdbSVictor Perez scf::populateSCFRotateWhileLoopPatterns(patterns); 62*09dfc571SJacques Pienaar (void)applyPatternsGreedily(func, std::move(patterns)); 63e8f07cdbSVictor Perez } 64f7201505SJerry Wu } 65f7201505SJerry Wu 66f7201505SJerry Wu Option<bool> forceCreateCheck{ 67f7201505SJerry Wu *this, "force-create-check", 68f7201505SJerry Wu llvm::cl::desc("Force to create zero-trip-check."), 69f7201505SJerry Wu llvm::cl::init(false)}; 70f7201505SJerry Wu }; 71f7201505SJerry Wu 72f7201505SJerry Wu } // namespace 73f7201505SJerry Wu 74f7201505SJerry Wu namespace mlir { 75f7201505SJerry Wu namespace test { 76f7201505SJerry Wu void registerTestSCFWrapInZeroTripCheckPasses() { 77f7201505SJerry Wu PassRegistration<TestWrapWhileLoopInZeroTripCheckPass>(); 78f7201505SJerry Wu } 79f7201505SJerry Wu } // namespace test 80f7201505SJerry Wu } // namespace mlir 81