1 //===- RotateWhileLoop.cpp - scf.while loop rotation ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Rotates `scf.while` loops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/Patterns.h" 14 15 #include "mlir/Dialect/SCF/IR/SCF.h" 16 17 using namespace mlir; 18 19 namespace { 20 struct RotateWhileLoopPattern : OpRewritePattern<scf::WhileOp> { 21 using OpRewritePattern<scf::WhileOp>::OpRewritePattern; 22 23 LogicalResult matchAndRewrite(scf::WhileOp whileOp, 24 PatternRewriter &rewriter) const final { 25 // Setting this option would lead to infinite recursion on a greedy driver 26 // as 'do-while' loops wouldn't be skipped. 27 constexpr bool forceCreateCheck = false; 28 FailureOr<scf::WhileOp> result = 29 scf::wrapWhileLoopInZeroTripCheck(whileOp, rewriter, forceCreateCheck); 30 // scf::wrapWhileLoopInZeroTripCheck hasn't yet implemented a failure 31 // mechanism. 'do-while' loops are simply returned unmodified. In order to 32 // stop recursion, we check input and output operations differ. 33 return success(succeeded(result) && *result != whileOp); 34 } 35 }; 36 } // namespace 37 38 namespace mlir { 39 namespace scf { 40 void populateSCFRotateWhileLoopPatterns(RewritePatternSet &patterns) { 41 patterns.add<RotateWhileLoopPattern>(patterns.getContext()); 42 } 43 } // namespace scf 44 } // namespace mlir 45