xref: /llvm-project/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp (revision bbc0e631d2d3facd5952aeafc7400761813acc3a)
13b45fe2eSMarkus Böck //===- ControlFlowToSCF.h - ControlFlow to SCF -------------*- C++ ------*-===//
23b45fe2eSMarkus Böck //
33b45fe2eSMarkus Böck // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43b45fe2eSMarkus Böck // See https://llvm.org/LICENSE.txt for license information.
53b45fe2eSMarkus Böck // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63b45fe2eSMarkus Böck //
73b45fe2eSMarkus Böck //===----------------------------------------------------------------------===//
83b45fe2eSMarkus Böck //
93b45fe2eSMarkus Böck // Define conversions from the ControlFlow dialect to the SCF dialect.
103b45fe2eSMarkus Böck //
113b45fe2eSMarkus Böck //===----------------------------------------------------------------------===//
123b45fe2eSMarkus Böck 
133b45fe2eSMarkus Böck #include "mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h"
143b45fe2eSMarkus Böck 
153b45fe2eSMarkus Böck #include "mlir/Dialect/Arith/IR/Arith.h"
163b45fe2eSMarkus Böck #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
173b45fe2eSMarkus Böck #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
183b45fe2eSMarkus Böck #include "mlir/Dialect/Func/IR/FuncOps.h"
193b45fe2eSMarkus Böck #include "mlir/Dialect/SCF/IR/SCF.h"
203b45fe2eSMarkus Böck #include "mlir/Dialect/UB/IR/UBOps.h"
213b45fe2eSMarkus Böck #include "mlir/Pass/Pass.h"
223b45fe2eSMarkus Böck #include "mlir/Transforms/CFGToSCF.h"
233b45fe2eSMarkus Böck 
243b45fe2eSMarkus Böck namespace mlir {
253b45fe2eSMarkus Böck #define GEN_PASS_DEF_LIFTCONTROLFLOWTOSCFPASS
263b45fe2eSMarkus Böck #include "mlir/Conversion/Passes.h.inc"
273b45fe2eSMarkus Böck } // namespace mlir
283b45fe2eSMarkus Böck 
293b45fe2eSMarkus Böck using namespace mlir;
303b45fe2eSMarkus Böck 
3190b25562SMarkus Böck FailureOr<Operation *>
3290b25562SMarkus Böck ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
333b45fe2eSMarkus Böck     OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes,
3490b25562SMarkus Böck     MutableArrayRef<Region> regions) {
353b45fe2eSMarkus Böck   if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) {
363b45fe2eSMarkus Böck     assert(regions.size() == 2);
3790b25562SMarkus Böck     auto ifOp = builder.create<scf::IfOp>(controlFlowCondOp->getLoc(),
3890b25562SMarkus Böck                                           resultTypes, condBrOp.getCondition());
393b45fe2eSMarkus Böck     ifOp.getThenRegion().takeBody(regions[0]);
403b45fe2eSMarkus Böck     ifOp.getElseRegion().takeBody(regions[1]);
413b45fe2eSMarkus Böck     return ifOp.getOperation();
423b45fe2eSMarkus Böck   }
433b45fe2eSMarkus Böck 
443b45fe2eSMarkus Böck   if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) {
453b45fe2eSMarkus Böck     // `getCFGSwitchValue` returns an i32 that we need to convert to index
463b45fe2eSMarkus Böck     // fist.
473b45fe2eSMarkus Böck     auto cast = builder.create<arith::IndexCastUIOp>(
483b45fe2eSMarkus Böck         controlFlowCondOp->getLoc(), builder.getIndexType(),
493b45fe2eSMarkus Böck         switchOp.getFlag());
503b45fe2eSMarkus Böck     SmallVector<int64_t> cases;
513b45fe2eSMarkus Böck     if (auto caseValues = switchOp.getCaseValues())
523b45fe2eSMarkus Böck       llvm::append_range(
533b45fe2eSMarkus Böck           cases, llvm::map_range(*caseValues, [](const llvm::APInt &apInt) {
543b45fe2eSMarkus Böck             return apInt.getZExtValue();
553b45fe2eSMarkus Böck           }));
563b45fe2eSMarkus Böck 
573b45fe2eSMarkus Böck     assert(regions.size() == cases.size() + 1);
583b45fe2eSMarkus Böck 
593b45fe2eSMarkus Böck     auto indexSwitchOp = builder.create<scf::IndexSwitchOp>(
603b45fe2eSMarkus Böck         controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size());
613b45fe2eSMarkus Böck 
623b45fe2eSMarkus Böck     indexSwitchOp.getDefaultRegion().takeBody(regions[0]);
633b45fe2eSMarkus Böck     for (auto &&[targetRegion, sourceRegion] :
643b45fe2eSMarkus Böck          llvm::zip(indexSwitchOp.getCaseRegions(), llvm::drop_begin(regions)))
653b45fe2eSMarkus Böck       targetRegion.takeBody(sourceRegion);
663b45fe2eSMarkus Böck 
673b45fe2eSMarkus Böck     return indexSwitchOp.getOperation();
683b45fe2eSMarkus Böck   }
693b45fe2eSMarkus Böck 
703b45fe2eSMarkus Böck   controlFlowCondOp->emitOpError(
713b45fe2eSMarkus Böck       "Cannot convert unknown control flow op to structured control flow");
723b45fe2eSMarkus Böck   return failure();
733b45fe2eSMarkus Böck }
743b45fe2eSMarkus Böck 
753b45fe2eSMarkus Böck LogicalResult
7690b25562SMarkus Böck ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp(
7790b25562SMarkus Böck     Location loc, OpBuilder &builder, Operation *branchRegionOp,
78359ba0b0SMarkus Böck     Operation *replacedControlFlowOp, ValueRange results) {
793b45fe2eSMarkus Böck   builder.create<scf::YieldOp>(loc, results);
803b45fe2eSMarkus Böck   return success();
813b45fe2eSMarkus Böck }
823b45fe2eSMarkus Böck 
833b45fe2eSMarkus Böck FailureOr<Operation *>
8490b25562SMarkus Böck ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp(
8590b25562SMarkus Böck     OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit,
8690b25562SMarkus Böck     Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) {
873b45fe2eSMarkus Böck   Location loc = replacedOp->getLoc();
8890b25562SMarkus Böck   auto whileOp = builder.create<scf::WhileOp>(loc, loopVariablesInit.getTypes(),
8990b25562SMarkus Böck                                               loopVariablesInit);
903b45fe2eSMarkus Böck 
913b45fe2eSMarkus Böck   whileOp.getBefore().takeBody(loopBody);
923b45fe2eSMarkus Böck 
933b45fe2eSMarkus Böck   builder.setInsertionPointToEnd(&whileOp.getBefore().back());
943b45fe2eSMarkus Böck   // `getCFGSwitchValue` returns a i32. We therefore need to truncate the
953b45fe2eSMarkus Böck   // condition to i1 first. It is guaranteed to be either 0 or 1 already.
963b45fe2eSMarkus Böck   builder.create<scf::ConditionOp>(
9790b25562SMarkus Böck       loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition),
983b45fe2eSMarkus Böck       loopVariablesNextIter);
993b45fe2eSMarkus Böck 
100*91d5653eSMatthias Springer   Block *afterBlock = builder.createBlock(&whileOp.getAfter());
1013b45fe2eSMarkus Böck   afterBlock->addArguments(
1023b45fe2eSMarkus Böck       loopVariablesInit.getTypes(),
1033b45fe2eSMarkus Böck       SmallVector<Location>(loopVariablesInit.size(), loc));
1043b45fe2eSMarkus Böck   builder.create<scf::YieldOp>(loc, afterBlock->getArguments());
1053b45fe2eSMarkus Böck 
1063b45fe2eSMarkus Böck   return whileOp.getOperation();
1073b45fe2eSMarkus Böck }
1083b45fe2eSMarkus Böck 
10990b25562SMarkus Böck Value ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc,
11090b25562SMarkus Böck                                                         OpBuilder &builder,
11190b25562SMarkus Böck                                                         unsigned int value) {
1123b45fe2eSMarkus Böck   return builder.create<arith::ConstantOp>(loc,
1133b45fe2eSMarkus Böck                                            builder.getI32IntegerAttr(value));
1143b45fe2eSMarkus Böck }
1153b45fe2eSMarkus Böck 
11690b25562SMarkus Böck void ControlFlowToSCFTransformation::createCFGSwitchOp(
11790b25562SMarkus Böck     Location loc, OpBuilder &builder, Value flag,
11890b25562SMarkus Böck     ArrayRef<unsigned int> caseValues, BlockRange caseDestinations,
1193b45fe2eSMarkus Böck     ArrayRef<ValueRange> caseArguments, Block *defaultDest,
12090b25562SMarkus Böck     ValueRange defaultArgs) {
1213b45fe2eSMarkus Böck   builder.create<cf::SwitchOp>(loc, flag, defaultDest, defaultArgs,
1223b45fe2eSMarkus Böck                                llvm::to_vector_of<int32_t>(caseValues),
1233b45fe2eSMarkus Böck                                caseDestinations, caseArguments);
1243b45fe2eSMarkus Böck }
1253b45fe2eSMarkus Böck 
12690b25562SMarkus Böck Value ControlFlowToSCFTransformation::getUndefValue(Location loc,
12790b25562SMarkus Böck                                                     OpBuilder &builder,
12890b25562SMarkus Böck                                                     Type type) {
1293b45fe2eSMarkus Böck   return builder.create<ub::PoisonOp>(loc, type, nullptr);
1303b45fe2eSMarkus Böck }
1313b45fe2eSMarkus Böck 
13290b25562SMarkus Böck FailureOr<Operation *>
13390b25562SMarkus Böck ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc,
1343b45fe2eSMarkus Böck                                                             OpBuilder &builder,
13590b25562SMarkus Böck                                                             Region &region) {
1363b45fe2eSMarkus Böck 
1373b45fe2eSMarkus Böck   // TODO: This should create a `ub.unreachable` op. Once such an operation
13890b25562SMarkus Böck   //       exists to make the pass independent of the func dialect. For now just
13990b25562SMarkus Böck   //       return poison values.
1400db1ae3eSIvan Butygin   Operation *parentOp = region.getParentOp();
1410db1ae3eSIvan Butygin   auto funcOp = dyn_cast<func::FuncOp>(parentOp);
1423b45fe2eSMarkus Böck   if (!funcOp)
1430db1ae3eSIvan Butygin     return emitError(loc, "Cannot create unreachable terminator for '")
1440db1ae3eSIvan Butygin            << parentOp->getName() << "'";
1453b45fe2eSMarkus Böck 
1463b45fe2eSMarkus Böck   return builder
1473b45fe2eSMarkus Böck       .create<func::ReturnOp>(
1483b45fe2eSMarkus Böck           loc, llvm::map_to_vector(funcOp.getResultTypes(),
1493b45fe2eSMarkus Böck                                    [&](Type type) {
1503b45fe2eSMarkus Böck                                      return getUndefValue(loc, builder, type);
1513b45fe2eSMarkus Böck                                    }))
1523b45fe2eSMarkus Böck       .getOperation();
1533b45fe2eSMarkus Böck }
15490b25562SMarkus Böck 
15590b25562SMarkus Böck namespace {
1563b45fe2eSMarkus Böck 
1573b45fe2eSMarkus Böck struct LiftControlFlowToSCF
1583b45fe2eSMarkus Böck     : public impl::LiftControlFlowToSCFPassBase<LiftControlFlowToSCF> {
1593b45fe2eSMarkus Böck 
1603b45fe2eSMarkus Böck   using Base::Base;
1613b45fe2eSMarkus Böck 
1623b45fe2eSMarkus Böck   void runOnOperation() override {
1633b45fe2eSMarkus Böck     ControlFlowToSCFTransformation transformation;
1643b45fe2eSMarkus Böck 
1653b45fe2eSMarkus Böck     bool changed = false;
1660db1ae3eSIvan Butygin     Operation *op = getOperation();
1670db1ae3eSIvan Butygin     WalkResult result = op->walk([&](func::FuncOp funcOp) {
1683b45fe2eSMarkus Böck       if (funcOp.getBody().empty())
1693b45fe2eSMarkus Böck         return WalkResult::advance();
1703b45fe2eSMarkus Böck 
1710db1ae3eSIvan Butygin       auto &domInfo = funcOp != op ? getChildAnalysis<DominanceInfo>(funcOp)
1720db1ae3eSIvan Butygin                                    : getAnalysis<DominanceInfo>();
1730db1ae3eSIvan Butygin 
1740db1ae3eSIvan Butygin       auto visitor = [&](Operation *innerOp) -> WalkResult {
1750db1ae3eSIvan Butygin         for (Region &reg : innerOp->getRegions()) {
1760db1ae3eSIvan Butygin           FailureOr<bool> changedFunc =
1770db1ae3eSIvan Butygin               transformCFGToSCF(reg, transformation, domInfo);
1783b45fe2eSMarkus Böck           if (failed(changedFunc))
1793b45fe2eSMarkus Böck             return WalkResult::interrupt();
1803b45fe2eSMarkus Böck 
1813b45fe2eSMarkus Böck           changed |= *changedFunc;
1820db1ae3eSIvan Butygin         }
1830db1ae3eSIvan Butygin         return WalkResult::advance();
1840db1ae3eSIvan Butygin       };
1850db1ae3eSIvan Butygin 
1860db1ae3eSIvan Butygin       if (funcOp->walk<WalkOrder::PostOrder>(visitor).wasInterrupted())
1870db1ae3eSIvan Butygin         return WalkResult::interrupt();
1880db1ae3eSIvan Butygin 
1893b45fe2eSMarkus Böck       return WalkResult::advance();
1903b45fe2eSMarkus Böck     });
1913b45fe2eSMarkus Böck     if (result.wasInterrupted())
1923b45fe2eSMarkus Böck       return signalPassFailure();
1933b45fe2eSMarkus Böck 
1943b45fe2eSMarkus Böck     if (!changed)
1953b45fe2eSMarkus Böck       markAllAnalysesPreserved();
1963b45fe2eSMarkus Böck   }
1973b45fe2eSMarkus Böck };
1983b45fe2eSMarkus Böck } // namespace
199