xref: /llvm-project/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp (revision f49d069ac03baffd64b22792001ec40f973b3178)
1 //===- AffineLoopNormalize.cpp - AffineLoopNormalize Pass -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a normalizer for affine loop-like ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/Passes.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Affine/Utils.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 
19 namespace mlir {
20 #define GEN_PASS_DEF_AFFINELOOPNORMALIZE
21 #include "mlir/Dialect/Affine/Passes.h.inc"
22 } // namespace mlir
23 
24 using namespace mlir;
25 
26 namespace {
27 
28 /// Normalize affine.parallel ops so that lower bounds are 0 and steps are 1.
29 /// As currently implemented, this pass cannot fail, but it might skip over ops
30 /// that are already in a normalized form.
31 struct AffineLoopNormalizePass
32     : public impl::AffineLoopNormalizeBase<AffineLoopNormalizePass> {
33   explicit AffineLoopNormalizePass(bool promoteSingleIter) {
34     this->promoteSingleIter = promoteSingleIter;
35   }
36 
37   void runOnOperation() override {
38     getOperation().walk([&](Operation *op) {
39       if (auto affineParallel = dyn_cast<AffineParallelOp>(op))
40         normalizeAffineParallel(affineParallel);
41       else if (auto affineFor = dyn_cast<AffineForOp>(op))
42         (void)normalizeAffineFor(affineFor, promoteSingleIter);
43     });
44   }
45 };
46 
47 } // namespace
48 
49 std::unique_ptr<OperationPass<func::FuncOp>>
50 mlir::createAffineLoopNormalizePass(bool promoteSingleIter) {
51   return std::make_unique<AffineLoopNormalizePass>(promoteSingleIter);
52 }
53