1 //===-- ControlFlowConverter.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Dialect/FIRDialect.h" 10 #include "flang/Optimizer/Dialect/FIROps.h" 11 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 12 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 13 #include "flang/Optimizer/Dialect/Support/KindMapping.h" 14 #include "flang/Optimizer/Support/InternalNames.h" 15 #include "flang/Optimizer/Support/TypeCode.h" 16 #include "flang/Optimizer/Transforms/Passes.h" 17 #include "flang/Runtime/derived-api.h" 18 #include "mlir/Dialect/Affine/IR/AffineOps.h" 19 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 20 #include "mlir/Dialect/Func/IR/FuncOps.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Transforms/DialectConversion.h" 23 #include "llvm/ADT/SmallSet.h" 24 #include "llvm/Support/CommandLine.h" 25 26 namespace fir { 27 #define GEN_PASS_DEF_CFGCONVERSION 28 #include "flang/Optimizer/Transforms/Passes.h.inc" 29 } // namespace fir 30 31 using namespace fir; 32 using namespace mlir; 33 34 namespace { 35 36 // Conversion of fir control ops to more primitive control-flow. 37 // 38 // FIR loops that cannot be converted to the affine dialect will remain as 39 // `fir.do_loop` operations. These can be converted to control-flow operations. 40 41 /// Convert `fir.do_loop` to CFG 42 class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> { 43 public: 44 using OpRewritePattern::OpRewritePattern; 45 46 CfgLoopConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce, bool setNSW) 47 : mlir::OpRewritePattern<fir::DoLoopOp>(ctx), 48 forceLoopToExecuteOnce(forceLoopToExecuteOnce), setNSW(setNSW) {} 49 50 llvm::LogicalResult 51 matchAndRewrite(DoLoopOp loop, 52 mlir::PatternRewriter &rewriter) const override { 53 auto loc = loop.getLoc(); 54 mlir::arith::IntegerOverflowFlags flags{}; 55 if (setNSW) 56 flags = bitEnumSet(flags, mlir::arith::IntegerOverflowFlags::nsw); 57 auto iofAttr = mlir::arith::IntegerOverflowFlagsAttr::get( 58 rewriter.getContext(), flags); 59 60 // Create the start and end blocks that will wrap the DoLoopOp with an 61 // initalizer and an end point 62 auto *initBlock = rewriter.getInsertionBlock(); 63 auto initPos = rewriter.getInsertionPoint(); 64 auto *endBlock = rewriter.splitBlock(initBlock, initPos); 65 66 // Split the first DoLoopOp block in two parts. The part before will be the 67 // conditional block since it already has the induction variable and 68 // loop-carried values as arguments. 69 auto *conditionalBlock = &loop.getRegion().front(); 70 conditionalBlock->addArgument(rewriter.getIndexType(), loc); 71 auto *firstBlock = 72 rewriter.splitBlock(conditionalBlock, conditionalBlock->begin()); 73 auto *lastBlock = &loop.getRegion().back(); 74 75 // Move the blocks from the DoLoopOp between initBlock and endBlock 76 rewriter.inlineRegionBefore(loop.getRegion(), endBlock); 77 78 // Get loop values from the DoLoopOp 79 auto low = loop.getLowerBound(); 80 auto high = loop.getUpperBound(); 81 assert(low && high && "must be a Value"); 82 auto step = loop.getStep(); 83 84 // Initalization block 85 rewriter.setInsertionPointToEnd(initBlock); 86 auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low); 87 auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step); 88 mlir::Value iters = 89 rewriter.create<mlir::arith::DivSIOp>(loc, distance, step); 90 91 if (forceLoopToExecuteOnce) { 92 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); 93 auto cond = rewriter.create<mlir::arith::CmpIOp>( 94 loc, arith::CmpIPredicate::sle, iters, zero); 95 auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1); 96 iters = rewriter.create<mlir::arith::SelectOp>(loc, cond, one, iters); 97 } 98 99 llvm::SmallVector<mlir::Value> loopOperands; 100 loopOperands.push_back(low); 101 auto operands = loop.getIterOperands(); 102 loopOperands.append(operands.begin(), operands.end()); 103 loopOperands.push_back(iters); 104 105 rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopOperands); 106 107 // Last loop block 108 auto *terminator = lastBlock->getTerminator(); 109 rewriter.setInsertionPointToEnd(lastBlock); 110 auto iv = conditionalBlock->getArgument(0); 111 mlir::Value steppedIndex = 112 rewriter.create<mlir::arith::AddIOp>(loc, iv, step, iofAttr); 113 assert(steppedIndex && "must be a Value"); 114 auto lastArg = conditionalBlock->getNumArguments() - 1; 115 auto itersLeft = conditionalBlock->getArgument(lastArg); 116 auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1); 117 mlir::Value itersMinusOne = 118 rewriter.create<mlir::arith::SubIOp>(loc, itersLeft, one); 119 120 llvm::SmallVector<mlir::Value> loopCarried; 121 loopCarried.push_back(steppedIndex); 122 auto begin = loop.getFinalValue() ? std::next(terminator->operand_begin()) 123 : terminator->operand_begin(); 124 loopCarried.append(begin, terminator->operand_end()); 125 loopCarried.push_back(itersMinusOne); 126 rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopCarried); 127 rewriter.eraseOp(terminator); 128 129 // Conditional block 130 rewriter.setInsertionPointToEnd(conditionalBlock); 131 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); 132 auto comparison = rewriter.create<mlir::arith::CmpIOp>( 133 loc, arith::CmpIPredicate::sgt, itersLeft, zero); 134 135 auto cond = rewriter.create<mlir::cf::CondBranchOp>( 136 loc, comparison, firstBlock, llvm::ArrayRef<mlir::Value>(), endBlock, 137 llvm::ArrayRef<mlir::Value>()); 138 139 // Copy loop annotations from the do loop to the loop entry condition. 140 if (auto ann = loop.getLoopAnnotation()) 141 cond->setAttr("loop_annotation", *ann); 142 143 // The result of the loop operation is the values of the condition block 144 // arguments except the induction variable on the last iteration. 145 auto args = loop.getFinalValue() 146 ? conditionalBlock->getArguments() 147 : conditionalBlock->getArguments().drop_front(); 148 rewriter.replaceOp(loop, args.drop_back()); 149 return success(); 150 } 151 152 private: 153 bool forceLoopToExecuteOnce; 154 bool setNSW; 155 }; 156 157 /// Convert `fir.if` to control-flow 158 class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> { 159 public: 160 using OpRewritePattern::OpRewritePattern; 161 162 CfgIfConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce, bool setNSW) 163 : mlir::OpRewritePattern<fir::IfOp>(ctx) {} 164 165 llvm::LogicalResult 166 matchAndRewrite(IfOp ifOp, mlir::PatternRewriter &rewriter) const override { 167 auto loc = ifOp.getLoc(); 168 169 // Split the block containing the 'fir.if' into two parts. The part before 170 // will contain the condition, the part after will be the continuation 171 // point. 172 auto *condBlock = rewriter.getInsertionBlock(); 173 auto opPosition = rewriter.getInsertionPoint(); 174 auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); 175 mlir::Block *continueBlock; 176 if (ifOp.getNumResults() == 0) { 177 continueBlock = remainingOpsBlock; 178 } else { 179 continueBlock = rewriter.createBlock( 180 remainingOpsBlock, ifOp.getResultTypes(), 181 llvm::SmallVector<mlir::Location>(ifOp.getNumResults(), loc)); 182 rewriter.create<mlir::cf::BranchOp>(loc, remainingOpsBlock); 183 } 184 185 // Move blocks from the "then" region to the region containing 'fir.if', 186 // place it before the continuation block, and branch to it. 187 auto &ifOpRegion = ifOp.getThenRegion(); 188 auto *ifOpBlock = &ifOpRegion.front(); 189 auto *ifOpTerminator = ifOpRegion.back().getTerminator(); 190 auto ifOpTerminatorOperands = ifOpTerminator->getOperands(); 191 rewriter.setInsertionPointToEnd(&ifOpRegion.back()); 192 rewriter.create<mlir::cf::BranchOp>(loc, continueBlock, 193 ifOpTerminatorOperands); 194 rewriter.eraseOp(ifOpTerminator); 195 rewriter.inlineRegionBefore(ifOpRegion, continueBlock); 196 197 // Move blocks from the "else" region (if present) to the region containing 198 // 'fir.if', place it before the continuation block and branch to it. It 199 // will be placed after the "then" regions. 200 auto *otherwiseBlock = continueBlock; 201 auto &otherwiseRegion = ifOp.getElseRegion(); 202 if (!otherwiseRegion.empty()) { 203 otherwiseBlock = &otherwiseRegion.front(); 204 auto *otherwiseTerm = otherwiseRegion.back().getTerminator(); 205 auto otherwiseTermOperands = otherwiseTerm->getOperands(); 206 rewriter.setInsertionPointToEnd(&otherwiseRegion.back()); 207 rewriter.create<mlir::cf::BranchOp>(loc, continueBlock, 208 otherwiseTermOperands); 209 rewriter.eraseOp(otherwiseTerm); 210 rewriter.inlineRegionBefore(otherwiseRegion, continueBlock); 211 } 212 213 rewriter.setInsertionPointToEnd(condBlock); 214 rewriter.create<mlir::cf::CondBranchOp>( 215 loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(), 216 otherwiseBlock, llvm::ArrayRef<mlir::Value>()); 217 rewriter.replaceOp(ifOp, continueBlock->getArguments()); 218 return success(); 219 } 220 }; 221 222 /// Convert `fir.iter_while` to control-flow. 223 class CfgIterWhileConv : public mlir::OpRewritePattern<fir::IterWhileOp> { 224 public: 225 using OpRewritePattern::OpRewritePattern; 226 227 CfgIterWhileConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce, 228 bool setNSW) 229 : mlir::OpRewritePattern<fir::IterWhileOp>(ctx), setNSW(setNSW) {} 230 231 llvm::LogicalResult 232 matchAndRewrite(fir::IterWhileOp whileOp, 233 mlir::PatternRewriter &rewriter) const override { 234 auto loc = whileOp.getLoc(); 235 mlir::arith::IntegerOverflowFlags flags{}; 236 if (setNSW) 237 flags = bitEnumSet(flags, mlir::arith::IntegerOverflowFlags::nsw); 238 auto iofAttr = mlir::arith::IntegerOverflowFlagsAttr::get( 239 rewriter.getContext(), flags); 240 241 // Start by splitting the block containing the 'fir.do_loop' into two parts. 242 // The part before will get the init code, the part after will be the end 243 // point. 244 auto *initBlock = rewriter.getInsertionBlock(); 245 auto initPosition = rewriter.getInsertionPoint(); 246 auto *endBlock = rewriter.splitBlock(initBlock, initPosition); 247 248 // Use the first block of the loop body as the condition block since it is 249 // the block that has the induction variable and loop-carried values as 250 // arguments. Split out all operations from the first block into a new 251 // block. Move all body blocks from the loop body region to the region 252 // containing the loop. 253 auto *conditionBlock = &whileOp.getRegion().front(); 254 auto *firstBodyBlock = 255 rewriter.splitBlock(conditionBlock, conditionBlock->begin()); 256 auto *lastBodyBlock = &whileOp.getRegion().back(); 257 rewriter.inlineRegionBefore(whileOp.getRegion(), endBlock); 258 auto iv = conditionBlock->getArgument(0); 259 auto iterateVar = conditionBlock->getArgument(1); 260 261 // Append the induction variable stepping logic to the last body block and 262 // branch back to the condition block. Loop-carried values are taken from 263 // operands of the loop terminator. 264 auto *terminator = lastBodyBlock->getTerminator(); 265 rewriter.setInsertionPointToEnd(lastBodyBlock); 266 auto step = whileOp.getStep(); 267 mlir::Value stepped = 268 rewriter.create<mlir::arith::AddIOp>(loc, iv, step, iofAttr); 269 assert(stepped && "must be a Value"); 270 271 llvm::SmallVector<mlir::Value> loopCarried; 272 loopCarried.push_back(stepped); 273 auto begin = whileOp.getFinalValue() 274 ? std::next(terminator->operand_begin()) 275 : terminator->operand_begin(); 276 loopCarried.append(begin, terminator->operand_end()); 277 rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, loopCarried); 278 rewriter.eraseOp(terminator); 279 280 // Compute loop bounds before branching to the condition. 281 rewriter.setInsertionPointToEnd(initBlock); 282 auto lowerBound = whileOp.getLowerBound(); 283 auto upperBound = whileOp.getUpperBound(); 284 assert(lowerBound && upperBound && "must be a Value"); 285 286 // The initial values of loop-carried values is obtained from the operands 287 // of the loop operation. 288 llvm::SmallVector<mlir::Value> destOperands; 289 destOperands.push_back(lowerBound); 290 auto iterOperands = whileOp.getIterOperands(); 291 destOperands.append(iterOperands.begin(), iterOperands.end()); 292 rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, destOperands); 293 294 // With the body block done, we can fill in the condition block. 295 rewriter.setInsertionPointToEnd(conditionBlock); 296 // The comparison depends on the sign of the step value. We fully expect 297 // this expression to be folded by the optimizer or LLVM. This expression 298 // is written this way so that `step == 0` always returns `false`. 299 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); 300 auto compl0 = rewriter.create<mlir::arith::CmpIOp>( 301 loc, arith::CmpIPredicate::slt, zero, step); 302 auto compl1 = rewriter.create<mlir::arith::CmpIOp>( 303 loc, arith::CmpIPredicate::sle, iv, upperBound); 304 auto compl2 = rewriter.create<mlir::arith::CmpIOp>( 305 loc, arith::CmpIPredicate::slt, step, zero); 306 auto compl3 = rewriter.create<mlir::arith::CmpIOp>( 307 loc, arith::CmpIPredicate::sle, upperBound, iv); 308 auto cmp0 = rewriter.create<mlir::arith::AndIOp>(loc, compl0, compl1); 309 auto cmp1 = rewriter.create<mlir::arith::AndIOp>(loc, compl2, compl3); 310 auto cmp2 = rewriter.create<mlir::arith::OrIOp>(loc, cmp0, cmp1); 311 // Remember to AND in the early-exit bool. 312 auto comparison = 313 rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2); 314 rewriter.create<mlir::cf::CondBranchOp>( 315 loc, comparison, firstBodyBlock, llvm::ArrayRef<mlir::Value>(), 316 endBlock, llvm::ArrayRef<mlir::Value>()); 317 // The result of the loop operation is the values of the condition block 318 // arguments except the induction variable on the last iteration. 319 auto args = whileOp.getFinalValue() 320 ? conditionBlock->getArguments() 321 : conditionBlock->getArguments().drop_front(); 322 rewriter.replaceOp(whileOp, args); 323 return success(); 324 } 325 326 private: 327 bool setNSW; 328 }; 329 330 /// Convert FIR structured control flow ops to CFG ops. 331 class CfgConversion : public fir::impl::CFGConversionBase<CfgConversion> { 332 public: 333 using CFGConversionBase<CfgConversion>::CFGConversionBase; 334 335 void runOnOperation() override { 336 auto *context = &this->getContext(); 337 mlir::RewritePatternSet patterns(context); 338 fir::populateCfgConversionRewrites(patterns, this->forceLoopToExecuteOnce, 339 this->setNSW); 340 mlir::ConversionTarget target(*context); 341 target.addLegalDialect<mlir::affine::AffineDialect, 342 mlir::cf::ControlFlowDialect, FIROpsDialect, 343 mlir::func::FuncDialect>(); 344 345 // apply the patterns 346 target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>(); 347 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); 348 if (mlir::failed(mlir::applyPartialConversion(this->getOperation(), target, 349 std::move(patterns)))) { 350 mlir::emitError(mlir::UnknownLoc::get(context), 351 "error in converting to CFG\n"); 352 this->signalPassFailure(); 353 } 354 } 355 }; 356 357 } // namespace 358 359 /// Expose conversion rewriters to other passes 360 void fir::populateCfgConversionRewrites(mlir::RewritePatternSet &patterns, 361 bool forceLoopToExecuteOnce, 362 bool setNSW) { 363 patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>( 364 patterns.getContext(), forceLoopToExecuteOnce, setNSW); 365 } 366