//===- ControlFlowToSCF.h - ControlFlow to SCF -------------*- C++ ------*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Define conversions from the ControlFlow dialect to the SCF dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/CFGToSCF.h" namespace mlir { #define GEN_PASS_DEF_LIFTCONTROLFLOWTOSCFPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; FailureOr ControlFlowToSCFTransformation::createStructuredBranchRegionOp( OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes, MutableArrayRef regions) { if (auto condBrOp = dyn_cast(controlFlowCondOp)) { assert(regions.size() == 2); auto ifOp = builder.create(controlFlowCondOp->getLoc(), resultTypes, condBrOp.getCondition()); ifOp.getThenRegion().takeBody(regions[0]); ifOp.getElseRegion().takeBody(regions[1]); return ifOp.getOperation(); } if (auto switchOp = dyn_cast(controlFlowCondOp)) { // `getCFGSwitchValue` returns an i32 that we need to convert to index // fist. auto cast = builder.create( controlFlowCondOp->getLoc(), builder.getIndexType(), switchOp.getFlag()); SmallVector cases; if (auto caseValues = switchOp.getCaseValues()) llvm::append_range( cases, llvm::map_range(*caseValues, [](const llvm::APInt &apInt) { return apInt.getZExtValue(); })); assert(regions.size() == cases.size() + 1); auto indexSwitchOp = builder.create( controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size()); indexSwitchOp.getDefaultRegion().takeBody(regions[0]); for (auto &&[targetRegion, sourceRegion] : llvm::zip(indexSwitchOp.getCaseRegions(), llvm::drop_begin(regions))) targetRegion.takeBody(sourceRegion); return indexSwitchOp.getOperation(); } controlFlowCondOp->emitOpError( "Cannot convert unknown control flow op to structured control flow"); return failure(); } LogicalResult ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp( Location loc, OpBuilder &builder, Operation *branchRegionOp, Operation *replacedControlFlowOp, ValueRange results) { builder.create(loc, results); return success(); } FailureOr ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp( OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit, Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) { Location loc = replacedOp->getLoc(); auto whileOp = builder.create(loc, loopVariablesInit.getTypes(), loopVariablesInit); whileOp.getBefore().takeBody(loopBody); builder.setInsertionPointToEnd(&whileOp.getBefore().back()); // `getCFGSwitchValue` returns a i32. We therefore need to truncate the // condition to i1 first. It is guaranteed to be either 0 or 1 already. builder.create( loc, builder.create(loc, builder.getI1Type(), condition), loopVariablesNextIter); Block *afterBlock = builder.createBlock(&whileOp.getAfter()); afterBlock->addArguments( loopVariablesInit.getTypes(), SmallVector(loopVariablesInit.size(), loc)); builder.create(loc, afterBlock->getArguments()); return whileOp.getOperation(); } Value ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc, OpBuilder &builder, unsigned int value) { return builder.create(loc, builder.getI32IntegerAttr(value)); } void ControlFlowToSCFTransformation::createCFGSwitchOp( Location loc, OpBuilder &builder, Value flag, ArrayRef caseValues, BlockRange caseDestinations, ArrayRef caseArguments, Block *defaultDest, ValueRange defaultArgs) { builder.create(loc, flag, defaultDest, defaultArgs, llvm::to_vector_of(caseValues), caseDestinations, caseArguments); } Value ControlFlowToSCFTransformation::getUndefValue(Location loc, OpBuilder &builder, Type type) { return builder.create(loc, type, nullptr); } FailureOr ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc, OpBuilder &builder, Region ®ion) { // TODO: This should create a `ub.unreachable` op. Once such an operation // exists to make the pass independent of the func dialect. For now just // return poison values. Operation *parentOp = region.getParentOp(); auto funcOp = dyn_cast(parentOp); if (!funcOp) return emitError(loc, "Cannot create unreachable terminator for '") << parentOp->getName() << "'"; return builder .create( loc, llvm::map_to_vector(funcOp.getResultTypes(), [&](Type type) { return getUndefValue(loc, builder, type); })) .getOperation(); } namespace { struct LiftControlFlowToSCF : public impl::LiftControlFlowToSCFPassBase { using Base::Base; void runOnOperation() override { ControlFlowToSCFTransformation transformation; bool changed = false; Operation *op = getOperation(); WalkResult result = op->walk([&](func::FuncOp funcOp) { if (funcOp.getBody().empty()) return WalkResult::advance(); auto &domInfo = funcOp != op ? getChildAnalysis(funcOp) : getAnalysis(); auto visitor = [&](Operation *innerOp) -> WalkResult { for (Region ® : innerOp->getRegions()) { FailureOr changedFunc = transformCFGToSCF(reg, transformation, domInfo); if (failed(changedFunc)) return WalkResult::interrupt(); changed |= *changedFunc; } return WalkResult::advance(); }; if (funcOp->walk(visitor).wasInterrupted()) return WalkResult::interrupt(); return WalkResult::advance(); }); if (result.wasInterrupted()) return signalPassFailure(); if (!changed) markAllAnalysesPreserved(); } }; } // namespace