1 //===- TestSCFWrapInZeroTripCheck.cpp -- Pass to test SCF zero-trip-check -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the passes to test wrap-in-zero-trip-check transforms on 10 // SCF loop ops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/Dialect/SCF/IR/SCF.h" 16 #include "mlir/Dialect/SCF/Transforms/Patterns.h" 17 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 using namespace mlir; 23 24 namespace { 25 26 struct TestWrapWhileLoopInZeroTripCheckPass 27 : public PassWrapper<TestWrapWhileLoopInZeroTripCheckPass, 28 OperationPass<func::FuncOp>> { 29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 30 TestWrapWhileLoopInZeroTripCheckPass) 31 32 StringRef getArgument() const final { 33 return "test-wrap-scf-while-loop-in-zero-trip-check"; 34 } 35 36 StringRef getDescription() const final { 37 return "test scf::wrapWhileLoopInZeroTripCheck"; 38 } 39 40 TestWrapWhileLoopInZeroTripCheckPass() = default; 41 TestWrapWhileLoopInZeroTripCheckPass( 42 const TestWrapWhileLoopInZeroTripCheckPass &) {} 43 explicit TestWrapWhileLoopInZeroTripCheckPass(bool forceCreateCheckParam) { 44 forceCreateCheck = forceCreateCheckParam; 45 } 46 47 void runOnOperation() override { 48 func::FuncOp func = getOperation(); 49 MLIRContext *context = &getContext(); 50 IRRewriter rewriter(context); 51 if (forceCreateCheck) { 52 func.walk([&](scf::WhileOp op) { 53 FailureOr<scf::WhileOp> result = 54 scf::wrapWhileLoopInZeroTripCheck(op, rewriter, forceCreateCheck); 55 // Ignore not implemented failure in tests. The expected output should 56 // catch problems (e.g. transformation doesn't happen). 57 (void)result; 58 }); 59 } else { 60 RewritePatternSet patterns(context); 61 scf::populateSCFRotateWhileLoopPatterns(patterns); 62 (void)applyPatternsGreedily(func, std::move(patterns)); 63 } 64 } 65 66 Option<bool> forceCreateCheck{ 67 *this, "force-create-check", 68 llvm::cl::desc("Force to create zero-trip-check."), 69 llvm::cl::init(false)}; 70 }; 71 72 } // namespace 73 74 namespace mlir { 75 namespace test { 76 void registerTestSCFWrapInZeroTripCheckPasses() { 77 PassRegistration<TestWrapWhileLoopInZeroTripCheckPass>(); 78 } 79 } // namespace test 80 } // namespace mlir 81