xref: /llvm-project/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp (revision bbc0e631d2d3facd5952aeafc7400761813acc3a)
1 //===- ControlFlowToSCF.cpp - ControlFlow to SCF -----------*- C++ ------*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Define conversions from the ControlFlow dialect to the SCF dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h"
14 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
17 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "mlir/Dialect/UB/IR/UBOps.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/CFGToSCF.h"
23 
24 namespace mlir {
25 #define GEN_PASS_DEF_LIFTCONTROLFLOWTOSCFPASS
26 #include "mlir/Conversion/Passes.h.inc"
27 } // namespace mlir
28 
29 using namespace mlir;
30 
31 FailureOr<Operation *>
32 ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
33     OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes,
34     MutableArrayRef<Region> regions) {
35   if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) {
36     assert(regions.size() == 2);
37     auto ifOp = builder.create<scf::IfOp>(controlFlowCondOp->getLoc(),
38                                           resultTypes, condBrOp.getCondition());
39     ifOp.getThenRegion().takeBody(regions[0]);
40     ifOp.getElseRegion().takeBody(regions[1]);
41     return ifOp.getOperation();
42   }
43 
44   if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) {
45     // `getCFGSwitchValue` returns an i32 that we need to convert to index
46     // fist.
47     auto cast = builder.create<arith::IndexCastUIOp>(
48         controlFlowCondOp->getLoc(), builder.getIndexType(),
49         switchOp.getFlag());
50     SmallVector<int64_t> cases;
51     if (auto caseValues = switchOp.getCaseValues())
52       llvm::append_range(
53           cases, llvm::map_range(*caseValues, [](const llvm::APInt &apInt) {
54             return apInt.getZExtValue();
55           }));
56 
57     assert(regions.size() == cases.size() + 1);
58 
59     auto indexSwitchOp = builder.create<scf::IndexSwitchOp>(
60         controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size());
61 
62     indexSwitchOp.getDefaultRegion().takeBody(regions[0]);
63     for (auto &&[targetRegion, sourceRegion] :
64          llvm::zip(indexSwitchOp.getCaseRegions(), llvm::drop_begin(regions)))
65       targetRegion.takeBody(sourceRegion);
66 
67     return indexSwitchOp.getOperation();
68   }
69 
70   controlFlowCondOp->emitOpError(
71       "Cannot convert unknown control flow op to structured control flow");
72   return failure();
73 }
74 
75 LogicalResult
76 ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp(
77     Location loc, OpBuilder &builder, Operation *branchRegionOp,
78     Operation *replacedControlFlowOp, ValueRange results) {
79   builder.create<scf::YieldOp>(loc, results);
80   return success();
81 }
82 
83 FailureOr<Operation *>
84 ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp(
85     OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit,
86     Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) {
87   Location loc = replacedOp->getLoc();
88   auto whileOp = builder.create<scf::WhileOp>(loc, loopVariablesInit.getTypes(),
89                                               loopVariablesInit);
90 
91   whileOp.getBefore().takeBody(loopBody);
92 
93   builder.setInsertionPointToEnd(&whileOp.getBefore().back());
94   // `getCFGSwitchValue` returns a i32. We therefore need to truncate the
95   // condition to i1 first. It is guaranteed to be either 0 or 1 already.
96   builder.create<scf::ConditionOp>(
97       loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition),
98       loopVariablesNextIter);
99 
100   Block *afterBlock = builder.createBlock(&whileOp.getAfter());
101   afterBlock->addArguments(
102       loopVariablesInit.getTypes(),
103       SmallVector<Location>(loopVariablesInit.size(), loc));
104   builder.create<scf::YieldOp>(loc, afterBlock->getArguments());
105 
106   return whileOp.getOperation();
107 }
108 
109 Value ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc,
110                                                         OpBuilder &builder,
111                                                         unsigned int value) {
112   return builder.create<arith::ConstantOp>(loc,
113                                            builder.getI32IntegerAttr(value));
114 }
115 
116 void ControlFlowToSCFTransformation::createCFGSwitchOp(
117     Location loc, OpBuilder &builder, Value flag,
118     ArrayRef<unsigned int> caseValues, BlockRange caseDestinations,
119     ArrayRef<ValueRange> caseArguments, Block *defaultDest,
120     ValueRange defaultArgs) {
121   builder.create<cf::SwitchOp>(loc, flag, defaultDest, defaultArgs,
122                                llvm::to_vector_of<int32_t>(caseValues),
123                                caseDestinations, caseArguments);
124 }
125 
126 Value ControlFlowToSCFTransformation::getUndefValue(Location loc,
127                                                     OpBuilder &builder,
128                                                     Type type) {
129   return builder.create<ub::PoisonOp>(loc, type, nullptr);
130 }
131 
132 FailureOr<Operation *>
133 ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc,
134                                                             OpBuilder &builder,
135                                                             Region &region) {
136 
137   // TODO: This should create a `ub.unreachable` op. Once such an operation
138   //       exists to make the pass independent of the func dialect. For now just
139   //       return poison values.
140   Operation *parentOp = region.getParentOp();
141   auto funcOp = dyn_cast<func::FuncOp>(parentOp);
142   if (!funcOp)
143     return emitError(loc, "Cannot create unreachable terminator for '")
144            << parentOp->getName() << "'";
145 
146   return builder
147       .create<func::ReturnOp>(
148           loc, llvm::map_to_vector(funcOp.getResultTypes(),
149                                    [&](Type type) {
150                                      return getUndefValue(loc, builder, type);
151                                    }))
152       .getOperation();
153 }
154 
155 namespace {
156 
157 struct LiftControlFlowToSCF
158     : public impl::LiftControlFlowToSCFPassBase<LiftControlFlowToSCF> {
159 
160   using Base::Base;
161 
162   void runOnOperation() override {
163     ControlFlowToSCFTransformation transformation;
164 
165     bool changed = false;
166     Operation *op = getOperation();
167     WalkResult result = op->walk([&](func::FuncOp funcOp) {
168       if (funcOp.getBody().empty())
169         return WalkResult::advance();
170 
171       auto &domInfo = funcOp != op ? getChildAnalysis<DominanceInfo>(funcOp)
172                                    : getAnalysis<DominanceInfo>();
173 
174       auto visitor = [&](Operation *innerOp) -> WalkResult {
175         for (Region &reg : innerOp->getRegions()) {
176           FailureOr<bool> changedFunc =
177               transformCFGToSCF(reg, transformation, domInfo);
178           if (failed(changedFunc))
179             return WalkResult::interrupt();
180 
181           changed |= *changedFunc;
182         }
183         return WalkResult::advance();
184       };
185 
186       if (funcOp->walk<WalkOrder::PostOrder>(visitor).wasInterrupted())
187         return WalkResult::interrupt();
188 
189       return WalkResult::advance();
190     });
191     if (result.wasInterrupted())
192       return signalPassFailure();
193 
194     if (!changed)
195       markAllAnalysesPreserved();
196   }
197 };
198 } // namespace
199