1 //===-- AffinePromotion.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation is a prototype that promote FIR loops operations 10 // to affine dialect operations. 11 // It is not part of the production pipeline and would need more work in order 12 // to be used in production. 13 // More information can be found in this presentation: 14 // https://slides.com/rajanwalia/deck 15 // 16 //===----------------------------------------------------------------------===// 17 18 #include "PassDetail.h" 19 #include "flang/Optimizer/Dialect/FIRDialect.h" 20 #include "flang/Optimizer/Dialect/FIROps.h" 21 #include "flang/Optimizer/Dialect/FIRType.h" 22 #include "flang/Optimizer/Transforms/Passes.h" 23 #include "mlir/Dialect/Affine/IR/AffineOps.h" 24 #include "mlir/Dialect/Func/IR/FuncOps.h" 25 #include "mlir/Dialect/SCF/IR/SCF.h" 26 #include "mlir/IR/BuiltinAttributes.h" 27 #include "mlir/IR/IntegerSet.h" 28 #include "mlir/IR/Visitors.h" 29 #include "mlir/Transforms/DialectConversion.h" 30 #include "llvm/ADT/DenseMap.h" 31 #include "llvm/ADT/Optional.h" 32 #include "llvm/Support/Debug.h" 33 34 #define DEBUG_TYPE "flang-affine-promotion" 35 36 using namespace fir; 37 using namespace mlir; 38 39 namespace { 40 struct AffineLoopAnalysis; 41 struct AffineIfAnalysis; 42 43 /// Stores analysis objects for all loops and if operations inside a function 44 /// these analysis are used twice, first for marking operations for rewrite and 45 /// second when doing rewrite. 46 struct AffineFunctionAnalysis { 47 explicit AffineFunctionAnalysis(mlir::func::FuncOp funcOp) { 48 for (fir::DoLoopOp op : funcOp.getOps<fir::DoLoopOp>()) 49 loopAnalysisMap.try_emplace(op, op, *this); 50 } 51 52 AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const; 53 54 AffineIfAnalysis getChildIfAnalysis(fir::IfOp op) const; 55 56 llvm::DenseMap<mlir::Operation *, AffineLoopAnalysis> loopAnalysisMap; 57 llvm::DenseMap<mlir::Operation *, AffineIfAnalysis> ifAnalysisMap; 58 }; 59 } // namespace 60 61 static bool analyzeCoordinate(mlir::Value coordinate, mlir::Operation *op) { 62 if (auto blockArg = coordinate.dyn_cast<mlir::BlockArgument>()) { 63 if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp())) 64 return true; 65 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a " 66 "loop induction variable (owner not loopOp)\n"; 67 op->dump()); 68 return false; 69 } 70 LLVM_DEBUG( 71 llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a loop " 72 "induction variable (not a block argument)\n"; 73 op->dump(); coordinate.getDefiningOp()->dump()); 74 return false; 75 } 76 77 namespace { 78 struct AffineLoopAnalysis { 79 AffineLoopAnalysis() = default; 80 81 explicit AffineLoopAnalysis(fir::DoLoopOp op, AffineFunctionAnalysis &afa) 82 : legality(analyzeLoop(op, afa)) {} 83 84 bool canPromoteToAffine() { return legality; } 85 86 private: 87 bool analyzeBody(fir::DoLoopOp loopOperation, 88 AffineFunctionAnalysis &functionAnalysis) { 89 for (auto loopOp : loopOperation.getOps<fir::DoLoopOp>()) { 90 auto analysis = functionAnalysis.loopAnalysisMap 91 .try_emplace(loopOp, loopOp, functionAnalysis) 92 .first->getSecond(); 93 if (!analysis.canPromoteToAffine()) 94 return false; 95 } 96 for (auto ifOp : loopOperation.getOps<fir::IfOp>()) 97 functionAnalysis.ifAnalysisMap.try_emplace(ifOp, ifOp, functionAnalysis); 98 return true; 99 } 100 101 bool analyzeLoop(fir::DoLoopOp loopOperation, 102 AffineFunctionAnalysis &functionAnalysis) { 103 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n"; loopOperation.dump();); 104 return analyzeMemoryAccess(loopOperation) && 105 analyzeBody(loopOperation, functionAnalysis); 106 } 107 108 bool analyzeReference(mlir::Value memref, mlir::Operation *op) { 109 if (auto acoOp = memref.getDefiningOp<ArrayCoorOp>()) { 110 if (acoOp.getMemref().getType().isa<fir::BoxType>()) { 111 // TODO: Look if and how fir.box can be promoted to affine. 112 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: cannot promote loop, " 113 "array memory operation uses fir.box\n"; 114 op->dump(); acoOp.dump();); 115 return false; 116 } 117 bool canPromote = true; 118 for (auto coordinate : acoOp.getIndices()) 119 canPromote = canPromote && analyzeCoordinate(coordinate, op); 120 return canPromote; 121 } 122 if (auto coOp = memref.getDefiningOp<CoordinateOp>()) { 123 LLVM_DEBUG(llvm::dbgs() 124 << "AffineLoopAnalysis: cannot promote loop, " 125 "array memory operation uses non ArrayCoorOp\n"; 126 op->dump(); coOp.dump();); 127 128 return false; 129 } 130 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: unknown type of memory " 131 "reference for array load\n"; 132 op->dump();); 133 return false; 134 } 135 136 bool analyzeMemoryAccess(fir::DoLoopOp loopOperation) { 137 for (auto loadOp : loopOperation.getOps<fir::LoadOp>()) 138 if (!analyzeReference(loadOp.getMemref(), loadOp)) 139 return false; 140 for (auto storeOp : loopOperation.getOps<fir::StoreOp>()) 141 if (!analyzeReference(storeOp.getMemref(), storeOp)) 142 return false; 143 return true; 144 } 145 146 bool legality{}; 147 }; 148 } // namespace 149 150 AffineLoopAnalysis 151 AffineFunctionAnalysis::getChildLoopAnalysis(fir::DoLoopOp op) const { 152 auto it = loopAnalysisMap.find_as(op); 153 if (it == loopAnalysisMap.end()) { 154 LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n"; 155 op.dump();); 156 op.emitError("error in fetching loop analysis in AffineFunctionAnalysis\n"); 157 return {}; 158 } 159 return it->getSecond(); 160 } 161 162 namespace { 163 /// Calculates arguments for creating an IntegerSet. symCount, dimCount are the 164 /// final number of symbols and dimensions of the affine map. Integer set if 165 /// possible is in Optional IntegerSet. 166 struct AffineIfCondition { 167 using MaybeAffineExpr = llvm::Optional<mlir::AffineExpr>; 168 169 explicit AffineIfCondition(mlir::Value fc) : firCondition(fc) { 170 if (auto condDef = firCondition.getDefiningOp<mlir::arith::CmpIOp>()) 171 fromCmpIOp(condDef); 172 } 173 174 bool hasIntegerSet() const { return integerSet.has_value(); } 175 176 mlir::IntegerSet getIntegerSet() const { 177 assert(hasIntegerSet() && "integer set is missing"); 178 return integerSet.getValue(); 179 } 180 181 mlir::ValueRange getAffineArgs() const { return affineArgs; } 182 183 private: 184 MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, mlir::Value lhs, 185 mlir::Value rhs) { 186 return affineBinaryOp(kind, toAffineExpr(lhs), toAffineExpr(rhs)); 187 } 188 189 MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, MaybeAffineExpr lhs, 190 MaybeAffineExpr rhs) { 191 if (lhs && rhs) 192 return mlir::getAffineBinaryOpExpr(kind, lhs.getValue(), rhs.getValue()); 193 return {}; 194 } 195 196 MaybeAffineExpr toAffineExpr(MaybeAffineExpr e) { return e; } 197 198 MaybeAffineExpr toAffineExpr(int64_t value) { 199 return {mlir::getAffineConstantExpr(value, firCondition.getContext())}; 200 } 201 202 /// Returns an AffineExpr if it is a result of operations that can be done 203 /// in an affine expression, this includes -, +, *, rem, constant. 204 /// block arguments of a loopOp or forOp are used as dimensions 205 MaybeAffineExpr toAffineExpr(mlir::Value value) { 206 if (auto op = value.getDefiningOp<mlir::arith::SubIOp>()) 207 return affineBinaryOp( 208 mlir::AffineExprKind::Add, toAffineExpr(op.getLhs()), 209 affineBinaryOp(mlir::AffineExprKind::Mul, toAffineExpr(op.getRhs()), 210 toAffineExpr(-1))); 211 if (auto op = value.getDefiningOp<mlir::arith::AddIOp>()) 212 return affineBinaryOp(mlir::AffineExprKind::Add, op.getLhs(), 213 op.getRhs()); 214 if (auto op = value.getDefiningOp<mlir::arith::MulIOp>()) 215 return affineBinaryOp(mlir::AffineExprKind::Mul, op.getLhs(), 216 op.getRhs()); 217 if (auto op = value.getDefiningOp<mlir::arith::RemUIOp>()) 218 return affineBinaryOp(mlir::AffineExprKind::Mod, op.getLhs(), 219 op.getRhs()); 220 if (auto op = value.getDefiningOp<mlir::arith::ConstantOp>()) 221 if (auto intConstant = op.getValue().dyn_cast<IntegerAttr>()) 222 return toAffineExpr(intConstant.getInt()); 223 if (auto blockArg = value.dyn_cast<mlir::BlockArgument>()) { 224 affineArgs.push_back(value); 225 if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()) || 226 isa<mlir::AffineForOp>(blockArg.getOwner()->getParentOp())) 227 return {mlir::getAffineDimExpr(dimCount++, value.getContext())}; 228 return {mlir::getAffineSymbolExpr(symCount++, value.getContext())}; 229 } 230 return {}; 231 } 232 233 void fromCmpIOp(mlir::arith::CmpIOp cmpOp) { 234 auto lhsAffine = toAffineExpr(cmpOp.getLhs()); 235 auto rhsAffine = toAffineExpr(cmpOp.getRhs()); 236 if (!lhsAffine || !rhsAffine) 237 return; 238 auto constraintPair = constraint( 239 cmpOp.getPredicate(), rhsAffine.getValue() - lhsAffine.getValue()); 240 if (!constraintPair) 241 return; 242 integerSet = mlir::IntegerSet::get(dimCount, symCount, 243 {constraintPair.getValue().first}, 244 {constraintPair.getValue().second}); 245 return; 246 } 247 248 llvm::Optional<std::pair<AffineExpr, bool>> 249 constraint(mlir::arith::CmpIPredicate predicate, mlir::AffineExpr basic) { 250 switch (predicate) { 251 case mlir::arith::CmpIPredicate::slt: 252 return {std::make_pair(basic - 1, false)}; 253 case mlir::arith::CmpIPredicate::sle: 254 return {std::make_pair(basic, false)}; 255 case mlir::arith::CmpIPredicate::sgt: 256 return {std::make_pair(1 - basic, false)}; 257 case mlir::arith::CmpIPredicate::sge: 258 return {std::make_pair(0 - basic, false)}; 259 case mlir::arith::CmpIPredicate::eq: 260 return {std::make_pair(basic, true)}; 261 default: 262 return {}; 263 } 264 } 265 266 llvm::SmallVector<mlir::Value> affineArgs; 267 llvm::Optional<mlir::IntegerSet> integerSet; 268 mlir::Value firCondition; 269 unsigned symCount{0u}; 270 unsigned dimCount{0u}; 271 }; 272 } // namespace 273 274 namespace { 275 /// Analysis for affine promotion of fir.if 276 struct AffineIfAnalysis { 277 AffineIfAnalysis() = default; 278 279 explicit AffineIfAnalysis(fir::IfOp op, AffineFunctionAnalysis &afa) 280 : legality(analyzeIf(op, afa)) {} 281 282 bool canPromoteToAffine() { return legality; } 283 284 private: 285 bool analyzeIf(fir::IfOp op, AffineFunctionAnalysis &afa) { 286 if (op.getNumResults() == 0) 287 return true; 288 LLVM_DEBUG(llvm::dbgs() 289 << "AffineIfAnalysis: not promoting as op has results\n";); 290 return false; 291 } 292 293 bool legality{}; 294 }; 295 } // namespace 296 297 AffineIfAnalysis 298 AffineFunctionAnalysis::getChildIfAnalysis(fir::IfOp op) const { 299 auto it = ifAnalysisMap.find_as(op); 300 if (it == ifAnalysisMap.end()) { 301 LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n"; 302 op.dump();); 303 op.emitError("error in fetching if analysis in AffineFunctionAnalysis\n"); 304 return {}; 305 } 306 return it->getSecond(); 307 } 308 309 /// AffineMap rewriting fir.array_coor operation to affine apply, 310 /// %dim = fir.gendim %lowerBound, %upperBound, %stride 311 /// %a = fir.array_coor %arr(%dim) %i 312 /// returning affineMap = affine_map<(i)[lb, ub, st] -> (i*st - lb)> 313 static mlir::AffineMap createArrayIndexAffineMap(unsigned dimensions, 314 MLIRContext *context) { 315 auto index = mlir::getAffineConstantExpr(0, context); 316 auto accuExtent = mlir::getAffineConstantExpr(1, context); 317 for (unsigned i = 0; i < dimensions; ++i) { 318 mlir::AffineExpr idx = mlir::getAffineDimExpr(i, context), 319 lowerBound = mlir::getAffineSymbolExpr(i * 3, context), 320 currentExtent = 321 mlir::getAffineSymbolExpr(i * 3 + 1, context), 322 stride = mlir::getAffineSymbolExpr(i * 3 + 2, context), 323 currentPart = (idx * stride - lowerBound) * accuExtent; 324 index = currentPart + index; 325 accuExtent = accuExtent * currentExtent; 326 } 327 return mlir::AffineMap::get(dimensions, dimensions * 3, index); 328 } 329 330 static Optional<int64_t> constantIntegerLike(const mlir::Value value) { 331 if (auto definition = value.getDefiningOp<mlir::arith::ConstantOp>()) 332 if (auto stepAttr = definition.getValue().dyn_cast<IntegerAttr>()) 333 return stepAttr.getInt(); 334 return {}; 335 } 336 337 static mlir::Type coordinateArrayElement(fir::ArrayCoorOp op) { 338 if (auto refType = 339 op.getMemref().getType().dyn_cast_or_null<ReferenceType>()) { 340 if (auto seqType = refType.getEleTy().dyn_cast_or_null<SequenceType>()) { 341 return seqType.getEleTy(); 342 } 343 } 344 op.emitError( 345 "AffineLoopConversion: array type in coordinate operation not valid\n"); 346 return mlir::Type(); 347 } 348 349 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeOp shape, 350 SmallVectorImpl<mlir::Value> &indexArgs, 351 mlir::PatternRewriter &rewriter) { 352 auto one = rewriter.create<mlir::arith::ConstantOp>( 353 acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1)); 354 auto extents = shape.getExtents(); 355 for (auto i = extents.begin(); i < extents.end(); i++) { 356 indexArgs.push_back(one); 357 indexArgs.push_back(*i); 358 indexArgs.push_back(one); 359 } 360 } 361 362 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeShiftOp shape, 363 SmallVectorImpl<mlir::Value> &indexArgs, 364 mlir::PatternRewriter &rewriter) { 365 auto one = rewriter.create<mlir::arith::ConstantOp>( 366 acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1)); 367 auto extents = shape.getPairs(); 368 for (auto i = extents.begin(); i < extents.end();) { 369 indexArgs.push_back(*i++); 370 indexArgs.push_back(*i++); 371 indexArgs.push_back(one); 372 } 373 } 374 375 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::SliceOp slice, 376 SmallVectorImpl<mlir::Value> &indexArgs, 377 mlir::PatternRewriter &rewriter) { 378 auto extents = slice.getTriples(); 379 for (auto i = extents.begin(); i < extents.end();) { 380 indexArgs.push_back(*i++); 381 indexArgs.push_back(*i++); 382 indexArgs.push_back(*i++); 383 } 384 } 385 386 static void populateIndexArgs(fir::ArrayCoorOp acoOp, 387 SmallVectorImpl<mlir::Value> &indexArgs, 388 mlir::PatternRewriter &rewriter) { 389 if (auto shape = acoOp.getShape().getDefiningOp<ShapeOp>()) 390 return populateIndexArgs(acoOp, shape, indexArgs, rewriter); 391 if (auto shapeShift = acoOp.getShape().getDefiningOp<ShapeShiftOp>()) 392 return populateIndexArgs(acoOp, shapeShift, indexArgs, rewriter); 393 if (auto slice = acoOp.getShape().getDefiningOp<SliceOp>()) 394 return populateIndexArgs(acoOp, slice, indexArgs, rewriter); 395 return; 396 } 397 398 /// Returns affine.apply and fir.convert from array_coor and gendims 399 static std::pair<mlir::AffineApplyOp, fir::ConvertOp> 400 createAffineOps(mlir::Value arrayRef, mlir::PatternRewriter &rewriter) { 401 auto acoOp = arrayRef.getDefiningOp<ArrayCoorOp>(); 402 auto affineMap = 403 createArrayIndexAffineMap(acoOp.getIndices().size(), acoOp.getContext()); 404 SmallVector<mlir::Value> indexArgs; 405 indexArgs.append(acoOp.getIndices().begin(), acoOp.getIndices().end()); 406 407 populateIndexArgs(acoOp, indexArgs, rewriter); 408 409 auto affineApply = rewriter.create<mlir::AffineApplyOp>(acoOp.getLoc(), 410 affineMap, indexArgs); 411 auto arrayElementType = coordinateArrayElement(acoOp); 412 auto newType = mlir::MemRefType::get({-1}, arrayElementType); 413 auto arrayConvert = rewriter.create<fir::ConvertOp>(acoOp.getLoc(), newType, 414 acoOp.getMemref()); 415 return std::make_pair(affineApply, arrayConvert); 416 } 417 418 static void rewriteLoad(fir::LoadOp loadOp, mlir::PatternRewriter &rewriter) { 419 rewriter.setInsertionPoint(loadOp); 420 auto affineOps = createAffineOps(loadOp.getMemref(), rewriter); 421 rewriter.replaceOpWithNewOp<mlir::AffineLoadOp>( 422 loadOp, affineOps.second.getResult(), affineOps.first.getResult()); 423 } 424 425 static void rewriteStore(fir::StoreOp storeOp, 426 mlir::PatternRewriter &rewriter) { 427 rewriter.setInsertionPoint(storeOp); 428 auto affineOps = createAffineOps(storeOp.getMemref(), rewriter); 429 rewriter.replaceOpWithNewOp<mlir::AffineStoreOp>(storeOp, storeOp.getValue(), 430 affineOps.second.getResult(), 431 affineOps.first.getResult()); 432 } 433 434 static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) { 435 for (auto &bodyOp : block->getOperations()) { 436 if (isa<fir::LoadOp>(bodyOp)) 437 rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter); 438 if (isa<fir::StoreOp>(bodyOp)) 439 rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter); 440 } 441 } 442 443 namespace { 444 /// Convert `fir.do_loop` to `affine.for`, creates fir.convert for arrays to 445 /// memref, rewrites array_coor to affine.apply with affine_map. Rewrites fir 446 /// loads and stores to affine. 447 class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> { 448 public: 449 using OpRewritePattern::OpRewritePattern; 450 AffineLoopConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa) 451 : OpRewritePattern(context), functionAnalysis(afa) {} 452 453 mlir::LogicalResult 454 matchAndRewrite(fir::DoLoopOp loop, 455 mlir::PatternRewriter &rewriter) const override { 456 LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n"; 457 loop.dump();); 458 LLVM_ATTRIBUTE_UNUSED auto loopAnalysis = 459 functionAnalysis.getChildLoopAnalysis(loop); 460 auto &loopOps = loop.getBody()->getOperations(); 461 auto loopAndIndex = createAffineFor(loop, rewriter); 462 auto affineFor = loopAndIndex.first; 463 auto inductionVar = loopAndIndex.second; 464 465 rewriter.startRootUpdate(affineFor.getOperation()); 466 affineFor.getBody()->getOperations().splice( 467 std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(), 468 std::prev(loopOps.end())); 469 rewriter.finalizeRootUpdate(affineFor.getOperation()); 470 471 rewriter.startRootUpdate(loop.getOperation()); 472 loop.getInductionVar().replaceAllUsesWith(inductionVar); 473 rewriter.finalizeRootUpdate(loop.getOperation()); 474 475 rewriteMemoryOps(affineFor.getBody(), rewriter); 476 477 LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n"; 478 affineFor.dump();); 479 rewriter.replaceOp(loop, affineFor.getOperation()->getResults()); 480 return success(); 481 } 482 483 private: 484 std::pair<mlir::AffineForOp, mlir::Value> 485 createAffineFor(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const { 486 if (auto constantStep = constantIntegerLike(op.getStep())) 487 if (constantStep.getValue() > 0) 488 return positiveConstantStep(op, constantStep.getValue(), rewriter); 489 return genericBounds(op, rewriter); 490 } 491 492 // when step for the loop is positive compile time constant 493 std::pair<mlir::AffineForOp, mlir::Value> 494 positiveConstantStep(fir::DoLoopOp op, int64_t step, 495 mlir::PatternRewriter &rewriter) const { 496 auto affineFor = rewriter.create<mlir::AffineForOp>( 497 op.getLoc(), ValueRange(op.getLowerBound()), 498 mlir::AffineMap::get(0, 1, 499 mlir::getAffineSymbolExpr(0, op.getContext())), 500 ValueRange(op.getUpperBound()), 501 mlir::AffineMap::get(0, 1, 502 1 + mlir::getAffineSymbolExpr(0, op.getContext())), 503 step); 504 return std::make_pair(affineFor, affineFor.getInductionVar()); 505 } 506 507 std::pair<mlir::AffineForOp, mlir::Value> 508 genericBounds(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const { 509 auto lowerBound = mlir::getAffineSymbolExpr(0, op.getContext()); 510 auto upperBound = mlir::getAffineSymbolExpr(1, op.getContext()); 511 auto step = mlir::getAffineSymbolExpr(2, op.getContext()); 512 mlir::AffineMap upperBoundMap = mlir::AffineMap::get( 513 0, 3, (upperBound - lowerBound + step).floorDiv(step)); 514 auto genericUpperBound = rewriter.create<mlir::AffineApplyOp>( 515 op.getLoc(), upperBoundMap, 516 ValueRange({op.getLowerBound(), op.getUpperBound(), op.getStep()})); 517 auto actualIndexMap = mlir::AffineMap::get( 518 1, 2, 519 (lowerBound + mlir::getAffineDimExpr(0, op.getContext())) * 520 mlir::getAffineSymbolExpr(1, op.getContext())); 521 522 auto affineFor = rewriter.create<mlir::AffineForOp>( 523 op.getLoc(), ValueRange(), 524 AffineMap::getConstantMap(0, op.getContext()), 525 genericUpperBound.getResult(), 526 mlir::AffineMap::get(0, 1, 527 1 + mlir::getAffineSymbolExpr(0, op.getContext())), 528 1); 529 rewriter.setInsertionPointToStart(affineFor.getBody()); 530 auto actualIndex = rewriter.create<mlir::AffineApplyOp>( 531 op.getLoc(), actualIndexMap, 532 ValueRange( 533 {affineFor.getInductionVar(), op.getLowerBound(), op.getStep()})); 534 return std::make_pair(affineFor, actualIndex.getResult()); 535 } 536 537 AffineFunctionAnalysis &functionAnalysis; 538 }; 539 540 /// Convert `fir.if` to `affine.if`. 541 class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> { 542 public: 543 using OpRewritePattern::OpRewritePattern; 544 AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa) 545 : OpRewritePattern(context) {} 546 mlir::LogicalResult 547 matchAndRewrite(fir::IfOp op, 548 mlir::PatternRewriter &rewriter) const override { 549 LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n"; 550 op.dump();); 551 auto &ifOps = op.getThenRegion().front().getOperations(); 552 auto affineCondition = AffineIfCondition(op.getCondition()); 553 if (!affineCondition.hasIntegerSet()) { 554 LLVM_DEBUG( 555 llvm::dbgs() 556 << "AffineIfConversion: couldn't calculate affine condition\n";); 557 return failure(); 558 } 559 auto affineIf = rewriter.create<mlir::AffineIfOp>( 560 op.getLoc(), affineCondition.getIntegerSet(), 561 affineCondition.getAffineArgs(), !op.getElseRegion().empty()); 562 rewriter.startRootUpdate(affineIf); 563 affineIf.getThenBlock()->getOperations().splice( 564 std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(), 565 std::prev(ifOps.end())); 566 if (!op.getElseRegion().empty()) { 567 auto &otherOps = op.getElseRegion().front().getOperations(); 568 affineIf.getElseBlock()->getOperations().splice( 569 std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(), 570 std::prev(otherOps.end())); 571 } 572 rewriter.finalizeRootUpdate(affineIf); 573 rewriteMemoryOps(affineIf.getBody(), rewriter); 574 575 LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n"; 576 affineIf.dump();); 577 rewriter.replaceOp(op, affineIf.getOperation()->getResults()); 578 return success(); 579 } 580 }; 581 582 /// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases 583 /// where such a promotion is possible. 584 class AffineDialectPromotion 585 : public AffineDialectPromotionBase<AffineDialectPromotion> { 586 public: 587 void runOnOperation() override { 588 589 auto *context = &getContext(); 590 auto function = getOperation(); 591 markAllAnalysesPreserved(); 592 auto functionAnalysis = AffineFunctionAnalysis(function); 593 mlir::RewritePatternSet patterns(context); 594 patterns.insert<AffineIfConversion>(context, functionAnalysis); 595 patterns.insert<AffineLoopConversion>(context, functionAnalysis); 596 mlir::ConversionTarget target = *context; 597 target.addLegalDialect< 598 mlir::AffineDialect, FIROpsDialect, mlir::scf::SCFDialect, 599 mlir::arith::ArithmeticDialect, mlir::func::FuncDialect>(); 600 target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) { 601 return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine()); 602 }); 603 target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis]( 604 fir::DoLoopOp op) { 605 return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine()); 606 }); 607 608 LLVM_DEBUG(llvm::dbgs() 609 << "AffineDialectPromotion: running promotion on: \n"; 610 function.print(llvm::dbgs());); 611 // apply the patterns 612 if (mlir::failed(mlir::applyPartialConversion(function, target, 613 std::move(patterns)))) { 614 mlir::emitError(mlir::UnknownLoc::get(context), 615 "error in converting to affine dialect\n"); 616 signalPassFailure(); 617 } 618 } 619 }; 620 } // namespace 621 622 /// Convert FIR loop constructs to the Affine dialect 623 std::unique_ptr<mlir::Pass> fir::createPromoteToAffinePass() { 624 return std::make_unique<AffineDialectPromotion>(); 625 } 626