xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp (revision e8f07cdb57602d71f8960c0499765bcb000745c2)
1 //===- RotateWhileLoop.cpp - scf.while loop rotation ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Rotates `scf.while` loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
14 
15 #include "mlir/Dialect/SCF/IR/SCF.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 struct RotateWhileLoopPattern : OpRewritePattern<scf::WhileOp> {
21   using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
22 
23   LogicalResult matchAndRewrite(scf::WhileOp whileOp,
24                                 PatternRewriter &rewriter) const final {
25     // Setting this option would lead to infinite recursion on a greedy driver
26     // as 'do-while' loops wouldn't be skipped.
27     constexpr bool forceCreateCheck = false;
28     FailureOr<scf::WhileOp> result =
29         scf::wrapWhileLoopInZeroTripCheck(whileOp, rewriter, forceCreateCheck);
30     // scf::wrapWhileLoopInZeroTripCheck hasn't yet implemented a failure
31     // mechanism. 'do-while' loops are simply returned unmodified. In order to
32     // stop recursion, we check input and output operations differ.
33     return success(succeeded(result) && *result != whileOp);
34   }
35 };
36 } // namespace
37 
38 namespace mlir {
39 namespace scf {
40 void populateSCFRotateWhileLoopPatterns(RewritePatternSet &patterns) {
41   patterns.add<RotateWhileLoopPattern>(patterns.getContext());
42 }
43 } // namespace scf
44 } // namespace mlir
45