1 //===- WrapInZeroTripCheck.cpp - Loop transforms to add zero-trip-check ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/SCF/IR/SCF.h"
10 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
11 #include "mlir/IR/IRMapping.h"
12 #include "mlir/IR/PatternMatch.h"
13
14 using namespace mlir;
15
16 /// Create zero-trip-check around a `while` op and return the new loop op in the
17 /// check. The while loop is rotated to avoid evaluating the condition twice.
18 ///
19 /// Given an example below:
20 ///
21 /// scf.while (%arg0 = %init) : (i32) -> i64 {
22 /// %val = .., %arg0 : i64
23 /// %cond = arith.cmpi .., %arg0 : i32
24 /// scf.condition(%cond) %val : i64
25 /// } do {
26 /// ^bb0(%arg1: i64):
27 /// %next = .., %arg1 : i32
28 /// scf.yield %next : i32
29 /// }
30 ///
31 /// First clone before block to the front of the loop:
32 ///
33 /// %pre_val = .., %init : i64
34 /// %pre_cond = arith.cmpi .., %init : i32
35 /// scf.while (%arg0 = %init) : (i32) -> i64 {
36 /// %val = .., %arg0 : i64
37 /// %cond = arith.cmpi .., %arg0 : i32
38 /// scf.condition(%cond) %val : i64
39 /// } do {
40 /// ^bb0(%arg1: i64):
41 /// %next = .., %arg1 : i32
42 /// scf.yield %next : i32
43 /// }
44 ///
45 /// Create `if` op with the condition, rotate and move the loop into the else
46 /// branch:
47 ///
48 /// %pre_val = .., %init : i64
49 /// %pre_cond = arith.cmpi .., %init : i32
50 /// scf.if %pre_cond -> i64 {
51 /// %res = scf.while (%arg1 = %va0) : (i64) -> i64 {
52 /// // Original after block
53 /// %next = .., %arg1 : i32
54 /// // Original before block
55 /// %val = .., %next : i64
56 /// %cond = arith.cmpi .., %next : i32
57 /// scf.condition(%cond) %val : i64
58 /// } do {
59 /// ^bb0(%arg2: i64):
60 /// %scf.yield %arg2 : i32
61 /// }
62 /// scf.yield %res : i64
63 /// } else {
64 /// scf.yield %pre_val : i64
65 /// }
wrapWhileLoopInZeroTripCheck(scf::WhileOp whileOp,RewriterBase & rewriter,bool forceCreateCheck)66 FailureOr<scf::WhileOp> mlir::scf::wrapWhileLoopInZeroTripCheck(
67 scf::WhileOp whileOp, RewriterBase &rewriter, bool forceCreateCheck) {
68 // If the loop is in do-while form (after block only passes through values),
69 // there is no need to create a zero-trip-check as before block is always run.
70 if (!forceCreateCheck && isa<scf::YieldOp>(whileOp.getAfterBody()->front())) {
71 return whileOp;
72 }
73
74 OpBuilder::InsertionGuard insertion_guard(rewriter);
75
76 IRMapping mapper;
77 Block *beforeBlock = whileOp.getBeforeBody();
78 // Clone before block before the loop for zero-trip-check.
79 for (auto [arg, init] :
80 llvm::zip_equal(beforeBlock->getArguments(), whileOp.getInits())) {
81 mapper.map(arg, init);
82 }
83 rewriter.setInsertionPoint(whileOp);
84 for (auto &op : *beforeBlock) {
85 if (isa<scf::ConditionOp>(op)) {
86 break;
87 }
88 // Safe to clone everything as in a single block all defs have been cloned
89 // and added to mapper in order.
90 rewriter.insert(op.clone(mapper));
91 }
92
93 scf::ConditionOp condOp = whileOp.getConditionOp();
94 Value clonedCondition = mapper.lookupOrDefault(condOp.getCondition());
95 SmallVector<Value> clonedCondArgs = llvm::map_to_vector(
96 condOp.getArgs(), [&](Value arg) { return mapper.lookupOrDefault(arg); });
97
98 // Create rotated while loop.
99 auto newLoopOp = rewriter.create<scf::WhileOp>(
100 whileOp.getLoc(), whileOp.getResultTypes(), clonedCondArgs,
101 [&](OpBuilder &builder, Location loc, ValueRange args) {
102 // Rotate and move the loop body into before block.
103 auto newBlock = builder.getBlock();
104 rewriter.mergeBlocks(whileOp.getAfterBody(), newBlock, args);
105 auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
106 rewriter.mergeBlocks(whileOp.getBeforeBody(), newBlock,
107 yieldOp.getResults());
108 rewriter.eraseOp(yieldOp);
109 },
110 [&](OpBuilder &builder, Location loc, ValueRange args) {
111 // Pass through values.
112 builder.create<scf::YieldOp>(loc, args);
113 });
114
115 // Create zero-trip-check and move the while loop in.
116 auto ifOp = rewriter.create<scf::IfOp>(
117 whileOp.getLoc(), clonedCondition,
118 [&](OpBuilder &builder, Location loc) {
119 // Then runs the while loop.
120 rewriter.moveOpBefore(newLoopOp, builder.getInsertionBlock(),
121 builder.getInsertionPoint());
122 builder.create<scf::YieldOp>(loc, newLoopOp.getResults());
123 },
124 [&](OpBuilder &builder, Location loc) {
125 // Else returns the results from precondition.
126 builder.create<scf::YieldOp>(loc, clonedCondArgs);
127 });
128
129 rewriter.replaceOp(whileOp, ifOp);
130
131 return newLoopOp;
132 }
133