xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp (revision f7201505a6ec7a0f904d2f09cece5c770058a991)
1 //===- WrapInZeroTripCheck.cpp - Loop transforms to add zero-trip-check ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/IR/SCF.h"
10 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
11 #include "mlir/IR/IRMapping.h"
12 #include "mlir/IR/PatternMatch.h"
13 
14 using namespace mlir;
15 
16 /// Create zero-trip-check around a `while` op and return the new loop op in the
17 /// check. The while loop is rotated to avoid evaluating the condition twice.
18 ///
19 /// Given an example below:
20 ///
21 ///   scf.while (%arg0 = %init) : (i32) -> i64 {
22 ///     %val = .., %arg0 : i64
23 ///     %cond = arith.cmpi .., %arg0 : i32
24 ///     scf.condition(%cond) %val : i64
25 ///   } do {
26 ///   ^bb0(%arg1: i64):
27 ///     %next = .., %arg1 : i32
28 ///     scf.yield %next : i32
29 ///   }
30 ///
31 /// First clone before block to the front of the loop:
32 ///
33 ///   %pre_val = .., %init : i64
34 ///   %pre_cond = arith.cmpi .., %init : i32
35 ///   scf.while (%arg0 = %init) : (i32) -> i64 {
36 ///     %val = .., %arg0 : i64
37 ///     %cond = arith.cmpi .., %arg0 : i32
38 ///     scf.condition(%cond) %val : i64
39 ///   } do {
40 ///   ^bb0(%arg1: i64):
41 ///     %next = .., %arg1 : i32
42 ///     scf.yield %next : i32
43 ///   }
44 ///
45 /// Create `if` op with the condition, rotate and move the loop into the else
46 /// branch:
47 ///
48 ///   %pre_val = .., %init : i64
49 ///   %pre_cond = arith.cmpi .., %init : i32
50 ///   scf.if %pre_cond -> i64 {
51 ///     %res = scf.while (%arg1 = %va0) : (i64) -> i64 {
52 ///       // Original after block
53 ///       %next = .., %arg1 : i32
54 ///       // Original before block
55 ///       %val = .., %next : i64
56 ///       %cond = arith.cmpi .., %next : i32
57 ///       scf.condition(%cond) %val : i64
58 ///     } do {
59 ///     ^bb0(%arg2: i64):
60 ///       %scf.yield %arg2 : i32
61 ///     }
62 ///     scf.yield %res : i64
63 ///   } else {
64 ///     scf.yield %pre_val : i64
65 ///   }
wrapWhileLoopInZeroTripCheck(scf::WhileOp whileOp,RewriterBase & rewriter,bool forceCreateCheck)66 FailureOr<scf::WhileOp> mlir::scf::wrapWhileLoopInZeroTripCheck(
67     scf::WhileOp whileOp, RewriterBase &rewriter, bool forceCreateCheck) {
68   // If the loop is in do-while form (after block only passes through values),
69   // there is no need to create a zero-trip-check as before block is always run.
70   if (!forceCreateCheck && isa<scf::YieldOp>(whileOp.getAfterBody()->front())) {
71     return whileOp;
72   }
73 
74   OpBuilder::InsertionGuard insertion_guard(rewriter);
75 
76   IRMapping mapper;
77   Block *beforeBlock = whileOp.getBeforeBody();
78   // Clone before block before the loop for zero-trip-check.
79   for (auto [arg, init] :
80        llvm::zip_equal(beforeBlock->getArguments(), whileOp.getInits())) {
81     mapper.map(arg, init);
82   }
83   rewriter.setInsertionPoint(whileOp);
84   for (auto &op : *beforeBlock) {
85     if (isa<scf::ConditionOp>(op)) {
86       break;
87     }
88     // Safe to clone everything as in a single block all defs have been cloned
89     // and added to mapper in order.
90     rewriter.insert(op.clone(mapper));
91   }
92 
93   scf::ConditionOp condOp = whileOp.getConditionOp();
94   Value clonedCondition = mapper.lookupOrDefault(condOp.getCondition());
95   SmallVector<Value> clonedCondArgs = llvm::map_to_vector(
96       condOp.getArgs(), [&](Value arg) { return mapper.lookupOrDefault(arg); });
97 
98   // Create rotated while loop.
99   auto newLoopOp = rewriter.create<scf::WhileOp>(
100       whileOp.getLoc(), whileOp.getResultTypes(), clonedCondArgs,
101       [&](OpBuilder &builder, Location loc, ValueRange args) {
102         // Rotate and move the loop body into before block.
103         auto newBlock = builder.getBlock();
104         rewriter.mergeBlocks(whileOp.getAfterBody(), newBlock, args);
105         auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
106         rewriter.mergeBlocks(whileOp.getBeforeBody(), newBlock,
107                              yieldOp.getResults());
108         rewriter.eraseOp(yieldOp);
109       },
110       [&](OpBuilder &builder, Location loc, ValueRange args) {
111         // Pass through values.
112         builder.create<scf::YieldOp>(loc, args);
113       });
114 
115   // Create zero-trip-check and move the while loop in.
116   auto ifOp = rewriter.create<scf::IfOp>(
117       whileOp.getLoc(), clonedCondition,
118       [&](OpBuilder &builder, Location loc) {
119         // Then runs the while loop.
120         rewriter.moveOpBefore(newLoopOp, builder.getInsertionBlock(),
121                               builder.getInsertionPoint());
122         builder.create<scf::YieldOp>(loc, newLoopOp.getResults());
123       },
124       [&](OpBuilder &builder, Location loc) {
125         // Else returns the results from precondition.
126         builder.create<scf::YieldOp>(loc, clonedCondArgs);
127       });
128 
129   rewriter.replaceOp(whileOp, ifOp);
130 
131   return newLoopOp;
132 }
133