1 //===- ControlFlowToSCF.h - ControlFlow to SCF -------------*- C++ ------*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Define conversions from the ControlFlow dialect to the SCF dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h" 14 15 #include "mlir/Dialect/Arith/IR/Arith.h" 16 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" 17 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/SCF/IR/SCF.h" 20 #include "mlir/Dialect/UB/IR/UBOps.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Transforms/CFGToSCF.h" 23 24 namespace mlir { 25 #define GEN_PASS_DEF_LIFTCONTROLFLOWTOSCFPASS 26 #include "mlir/Conversion/Passes.h.inc" 27 } // namespace mlir 28 29 using namespace mlir; 30 31 FailureOr<Operation *> 32 ControlFlowToSCFTransformation::createStructuredBranchRegionOp( 33 OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes, 34 MutableArrayRef<Region> regions) { 35 if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) { 36 assert(regions.size() == 2); 37 auto ifOp = builder.create<scf::IfOp>(controlFlowCondOp->getLoc(), 38 resultTypes, condBrOp.getCondition()); 39 ifOp.getThenRegion().takeBody(regions[0]); 40 ifOp.getElseRegion().takeBody(regions[1]); 41 return ifOp.getOperation(); 42 } 43 44 if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) { 45 // `getCFGSwitchValue` returns an i32 that we need to convert to index 46 // fist. 47 auto cast = builder.create<arith::IndexCastUIOp>( 48 controlFlowCondOp->getLoc(), builder.getIndexType(), 49 switchOp.getFlag()); 50 SmallVector<int64_t> cases; 51 if (auto caseValues = switchOp.getCaseValues()) 52 llvm::append_range( 53 cases, llvm::map_range(*caseValues, [](const llvm::APInt &apInt) { 54 return apInt.getZExtValue(); 55 })); 56 57 assert(regions.size() == cases.size() + 1); 58 59 auto indexSwitchOp = builder.create<scf::IndexSwitchOp>( 60 controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size()); 61 62 indexSwitchOp.getDefaultRegion().takeBody(regions[0]); 63 for (auto &&[targetRegion, sourceRegion] : 64 llvm::zip(indexSwitchOp.getCaseRegions(), llvm::drop_begin(regions))) 65 targetRegion.takeBody(sourceRegion); 66 67 return indexSwitchOp.getOperation(); 68 } 69 70 controlFlowCondOp->emitOpError( 71 "Cannot convert unknown control flow op to structured control flow"); 72 return failure(); 73 } 74 75 LogicalResult 76 ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp( 77 Location loc, OpBuilder &builder, Operation *branchRegionOp, 78 Operation *replacedControlFlowOp, ValueRange results) { 79 builder.create<scf::YieldOp>(loc, results); 80 return success(); 81 } 82 83 FailureOr<Operation *> 84 ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp( 85 OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit, 86 Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) { 87 Location loc = replacedOp->getLoc(); 88 auto whileOp = builder.create<scf::WhileOp>(loc, loopVariablesInit.getTypes(), 89 loopVariablesInit); 90 91 whileOp.getBefore().takeBody(loopBody); 92 93 builder.setInsertionPointToEnd(&whileOp.getBefore().back()); 94 // `getCFGSwitchValue` returns a i32. We therefore need to truncate the 95 // condition to i1 first. It is guaranteed to be either 0 or 1 already. 96 builder.create<scf::ConditionOp>( 97 loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition), 98 loopVariablesNextIter); 99 100 Block *afterBlock = builder.createBlock(&whileOp.getAfter()); 101 afterBlock->addArguments( 102 loopVariablesInit.getTypes(), 103 SmallVector<Location>(loopVariablesInit.size(), loc)); 104 builder.create<scf::YieldOp>(loc, afterBlock->getArguments()); 105 106 return whileOp.getOperation(); 107 } 108 109 Value ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc, 110 OpBuilder &builder, 111 unsigned int value) { 112 return builder.create<arith::ConstantOp>(loc, 113 builder.getI32IntegerAttr(value)); 114 } 115 116 void ControlFlowToSCFTransformation::createCFGSwitchOp( 117 Location loc, OpBuilder &builder, Value flag, 118 ArrayRef<unsigned int> caseValues, BlockRange caseDestinations, 119 ArrayRef<ValueRange> caseArguments, Block *defaultDest, 120 ValueRange defaultArgs) { 121 builder.create<cf::SwitchOp>(loc, flag, defaultDest, defaultArgs, 122 llvm::to_vector_of<int32_t>(caseValues), 123 caseDestinations, caseArguments); 124 } 125 126 Value ControlFlowToSCFTransformation::getUndefValue(Location loc, 127 OpBuilder &builder, 128 Type type) { 129 return builder.create<ub::PoisonOp>(loc, type, nullptr); 130 } 131 132 FailureOr<Operation *> 133 ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc, 134 OpBuilder &builder, 135 Region ®ion) { 136 137 // TODO: This should create a `ub.unreachable` op. Once such an operation 138 // exists to make the pass independent of the func dialect. For now just 139 // return poison values. 140 Operation *parentOp = region.getParentOp(); 141 auto funcOp = dyn_cast<func::FuncOp>(parentOp); 142 if (!funcOp) 143 return emitError(loc, "Cannot create unreachable terminator for '") 144 << parentOp->getName() << "'"; 145 146 return builder 147 .create<func::ReturnOp>( 148 loc, llvm::map_to_vector(funcOp.getResultTypes(), 149 [&](Type type) { 150 return getUndefValue(loc, builder, type); 151 })) 152 .getOperation(); 153 } 154 155 namespace { 156 157 struct LiftControlFlowToSCF 158 : public impl::LiftControlFlowToSCFPassBase<LiftControlFlowToSCF> { 159 160 using Base::Base; 161 162 void runOnOperation() override { 163 ControlFlowToSCFTransformation transformation; 164 165 bool changed = false; 166 Operation *op = getOperation(); 167 WalkResult result = op->walk([&](func::FuncOp funcOp) { 168 if (funcOp.getBody().empty()) 169 return WalkResult::advance(); 170 171 auto &domInfo = funcOp != op ? getChildAnalysis<DominanceInfo>(funcOp) 172 : getAnalysis<DominanceInfo>(); 173 174 auto visitor = [&](Operation *innerOp) -> WalkResult { 175 for (Region ® : innerOp->getRegions()) { 176 FailureOr<bool> changedFunc = 177 transformCFGToSCF(reg, transformation, domInfo); 178 if (failed(changedFunc)) 179 return WalkResult::interrupt(); 180 181 changed |= *changedFunc; 182 } 183 return WalkResult::advance(); 184 }; 185 186 if (funcOp->walk<WalkOrder::PostOrder>(visitor).wasInterrupted()) 187 return WalkResult::interrupt(); 188 189 return WalkResult::advance(); 190 }); 191 if (result.wasInterrupted()) 192 return signalPassFailure(); 193 194 if (!changed) 195 markAllAnalysesPreserved(); 196 } 197 }; 198 } // namespace 199