xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp (revision e8f07cdb57602d71f8960c0499765bcb000745c2)
1*e8f07cdbSVictor Perez //===- RotateWhileLoop.cpp - scf.while loop rotation ----------------------===//
2*e8f07cdbSVictor Perez //
3*e8f07cdbSVictor Perez // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*e8f07cdbSVictor Perez // See https://llvm.org/LICENSE.txt for license information.
5*e8f07cdbSVictor Perez // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*e8f07cdbSVictor Perez //
7*e8f07cdbSVictor Perez //===----------------------------------------------------------------------===//
8*e8f07cdbSVictor Perez //
9*e8f07cdbSVictor Perez // Rotates `scf.while` loops.
10*e8f07cdbSVictor Perez //
11*e8f07cdbSVictor Perez //===----------------------------------------------------------------------===//
12*e8f07cdbSVictor Perez 
13*e8f07cdbSVictor Perez #include "mlir/Dialect/SCF/Transforms/Patterns.h"
14*e8f07cdbSVictor Perez 
15*e8f07cdbSVictor Perez #include "mlir/Dialect/SCF/IR/SCF.h"
16*e8f07cdbSVictor Perez 
17*e8f07cdbSVictor Perez using namespace mlir;
18*e8f07cdbSVictor Perez 
19*e8f07cdbSVictor Perez namespace {
20*e8f07cdbSVictor Perez struct RotateWhileLoopPattern : OpRewritePattern<scf::WhileOp> {
21*e8f07cdbSVictor Perez   using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
22*e8f07cdbSVictor Perez 
23*e8f07cdbSVictor Perez   LogicalResult matchAndRewrite(scf::WhileOp whileOp,
24*e8f07cdbSVictor Perez                                 PatternRewriter &rewriter) const final {
25*e8f07cdbSVictor Perez     // Setting this option would lead to infinite recursion on a greedy driver
26*e8f07cdbSVictor Perez     // as 'do-while' loops wouldn't be skipped.
27*e8f07cdbSVictor Perez     constexpr bool forceCreateCheck = false;
28*e8f07cdbSVictor Perez     FailureOr<scf::WhileOp> result =
29*e8f07cdbSVictor Perez         scf::wrapWhileLoopInZeroTripCheck(whileOp, rewriter, forceCreateCheck);
30*e8f07cdbSVictor Perez     // scf::wrapWhileLoopInZeroTripCheck hasn't yet implemented a failure
31*e8f07cdbSVictor Perez     // mechanism. 'do-while' loops are simply returned unmodified. In order to
32*e8f07cdbSVictor Perez     // stop recursion, we check input and output operations differ.
33*e8f07cdbSVictor Perez     return success(succeeded(result) && *result != whileOp);
34*e8f07cdbSVictor Perez   }
35*e8f07cdbSVictor Perez };
36*e8f07cdbSVictor Perez } // namespace
37*e8f07cdbSVictor Perez 
38*e8f07cdbSVictor Perez namespace mlir {
39*e8f07cdbSVictor Perez namespace scf {
40*e8f07cdbSVictor Perez void populateSCFRotateWhileLoopPatterns(RewritePatternSet &patterns) {
41*e8f07cdbSVictor Perez   patterns.add<RotateWhileLoopPattern>(patterns.getContext());
42*e8f07cdbSVictor Perez }
43*e8f07cdbSVictor Perez } // namespace scf
44*e8f07cdbSVictor Perez } // namespace mlir
45