1 //===-- AffinePromotion.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation is a prototype that promote FIR loops operations 10 // to affine dialect operations. 11 // It is not part of the production pipeline and would need more work in order 12 // to be used in production. 13 // More information can be found in this presentation: 14 // https://slides.com/rajanwalia/deck 15 // 16 //===----------------------------------------------------------------------===// 17 18 #include "flang/Optimizer/Dialect/FIRDialect.h" 19 #include "flang/Optimizer/Dialect/FIROps.h" 20 #include "flang/Optimizer/Dialect/FIRType.h" 21 #include "flang/Optimizer/Transforms/Passes.h" 22 #include "mlir/Dialect/Affine/IR/AffineOps.h" 23 #include "mlir/Dialect/Func/IR/FuncOps.h" 24 #include "mlir/Dialect/SCF/IR/SCF.h" 25 #include "mlir/IR/BuiltinAttributes.h" 26 #include "mlir/IR/IntegerSet.h" 27 #include "mlir/IR/Visitors.h" 28 #include "mlir/Transforms/DialectConversion.h" 29 #include "llvm/ADT/DenseMap.h" 30 #include "llvm/ADT/Optional.h" 31 #include "llvm/Support/Debug.h" 32 #include <optional> 33 34 namespace fir { 35 #define GEN_PASS_DEF_AFFINEDIALECTPROMOTION 36 #include "flang/Optimizer/Transforms/Passes.h.inc" 37 } // namespace fir 38 39 #define DEBUG_TYPE "flang-affine-promotion" 40 41 using namespace fir; 42 using namespace mlir; 43 44 namespace { 45 struct AffineLoopAnalysis; 46 struct AffineIfAnalysis; 47 48 /// Stores analysis objects for all loops and if operations inside a function 49 /// these analysis are used twice, first for marking operations for rewrite and 50 /// second when doing rewrite. 51 struct AffineFunctionAnalysis { 52 explicit AffineFunctionAnalysis(mlir::func::FuncOp funcOp) { 53 for (fir::DoLoopOp op : funcOp.getOps<fir::DoLoopOp>()) 54 loopAnalysisMap.try_emplace(op, op, *this); 55 } 56 57 AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const; 58 59 AffineIfAnalysis getChildIfAnalysis(fir::IfOp op) const; 60 61 llvm::DenseMap<mlir::Operation *, AffineLoopAnalysis> loopAnalysisMap; 62 llvm::DenseMap<mlir::Operation *, AffineIfAnalysis> ifAnalysisMap; 63 }; 64 } // namespace 65 66 static bool analyzeCoordinate(mlir::Value coordinate, mlir::Operation *op) { 67 if (auto blockArg = coordinate.dyn_cast<mlir::BlockArgument>()) { 68 if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp())) 69 return true; 70 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a " 71 "loop induction variable (owner not loopOp)\n"; 72 op->dump()); 73 return false; 74 } 75 LLVM_DEBUG( 76 llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a loop " 77 "induction variable (not a block argument)\n"; 78 op->dump(); coordinate.getDefiningOp()->dump()); 79 return false; 80 } 81 82 namespace { 83 struct AffineLoopAnalysis { 84 AffineLoopAnalysis() = default; 85 86 explicit AffineLoopAnalysis(fir::DoLoopOp op, AffineFunctionAnalysis &afa) 87 : legality(analyzeLoop(op, afa)) {} 88 89 bool canPromoteToAffine() { return legality; } 90 91 private: 92 bool analyzeBody(fir::DoLoopOp loopOperation, 93 AffineFunctionAnalysis &functionAnalysis) { 94 for (auto loopOp : loopOperation.getOps<fir::DoLoopOp>()) { 95 auto analysis = functionAnalysis.loopAnalysisMap 96 .try_emplace(loopOp, loopOp, functionAnalysis) 97 .first->getSecond(); 98 if (!analysis.canPromoteToAffine()) 99 return false; 100 } 101 for (auto ifOp : loopOperation.getOps<fir::IfOp>()) 102 functionAnalysis.ifAnalysisMap.try_emplace(ifOp, ifOp, functionAnalysis); 103 return true; 104 } 105 106 bool analyzeLoop(fir::DoLoopOp loopOperation, 107 AffineFunctionAnalysis &functionAnalysis) { 108 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n"; loopOperation.dump();); 109 return analyzeMemoryAccess(loopOperation) && 110 analyzeBody(loopOperation, functionAnalysis); 111 } 112 113 bool analyzeReference(mlir::Value memref, mlir::Operation *op) { 114 if (auto acoOp = memref.getDefiningOp<ArrayCoorOp>()) { 115 if (acoOp.getMemref().getType().isa<fir::BoxType>()) { 116 // TODO: Look if and how fir.box can be promoted to affine. 117 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: cannot promote loop, " 118 "array memory operation uses fir.box\n"; 119 op->dump(); acoOp.dump();); 120 return false; 121 } 122 bool canPromote = true; 123 for (auto coordinate : acoOp.getIndices()) 124 canPromote = canPromote && analyzeCoordinate(coordinate, op); 125 return canPromote; 126 } 127 if (auto coOp = memref.getDefiningOp<CoordinateOp>()) { 128 LLVM_DEBUG(llvm::dbgs() 129 << "AffineLoopAnalysis: cannot promote loop, " 130 "array memory operation uses non ArrayCoorOp\n"; 131 op->dump(); coOp.dump();); 132 133 return false; 134 } 135 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: unknown type of memory " 136 "reference for array load\n"; 137 op->dump();); 138 return false; 139 } 140 141 bool analyzeMemoryAccess(fir::DoLoopOp loopOperation) { 142 for (auto loadOp : loopOperation.getOps<fir::LoadOp>()) 143 if (!analyzeReference(loadOp.getMemref(), loadOp)) 144 return false; 145 for (auto storeOp : loopOperation.getOps<fir::StoreOp>()) 146 if (!analyzeReference(storeOp.getMemref(), storeOp)) 147 return false; 148 return true; 149 } 150 151 bool legality{}; 152 }; 153 } // namespace 154 155 AffineLoopAnalysis 156 AffineFunctionAnalysis::getChildLoopAnalysis(fir::DoLoopOp op) const { 157 auto it = loopAnalysisMap.find_as(op); 158 if (it == loopAnalysisMap.end()) { 159 LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n"; 160 op.dump();); 161 op.emitError("error in fetching loop analysis in AffineFunctionAnalysis\n"); 162 return {}; 163 } 164 return it->getSecond(); 165 } 166 167 namespace { 168 /// Calculates arguments for creating an IntegerSet. symCount, dimCount are the 169 /// final number of symbols and dimensions of the affine map. Integer set if 170 /// possible is in Optional IntegerSet. 171 struct AffineIfCondition { 172 using MaybeAffineExpr = std::optional<mlir::AffineExpr>; 173 174 explicit AffineIfCondition(mlir::Value fc) : firCondition(fc) { 175 if (auto condDef = firCondition.getDefiningOp<mlir::arith::CmpIOp>()) 176 fromCmpIOp(condDef); 177 } 178 179 bool hasIntegerSet() const { return integerSet.has_value(); } 180 181 mlir::IntegerSet getIntegerSet() const { 182 assert(hasIntegerSet() && "integer set is missing"); 183 return *integerSet; 184 } 185 186 mlir::ValueRange getAffineArgs() const { return affineArgs; } 187 188 private: 189 MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, mlir::Value lhs, 190 mlir::Value rhs) { 191 return affineBinaryOp(kind, toAffineExpr(lhs), toAffineExpr(rhs)); 192 } 193 194 MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, MaybeAffineExpr lhs, 195 MaybeAffineExpr rhs) { 196 if (lhs && rhs) 197 return mlir::getAffineBinaryOpExpr(kind, *lhs, *rhs); 198 return {}; 199 } 200 201 MaybeAffineExpr toAffineExpr(MaybeAffineExpr e) { return e; } 202 203 MaybeAffineExpr toAffineExpr(int64_t value) { 204 return {mlir::getAffineConstantExpr(value, firCondition.getContext())}; 205 } 206 207 /// Returns an AffineExpr if it is a result of operations that can be done 208 /// in an affine expression, this includes -, +, *, rem, constant. 209 /// block arguments of a loopOp or forOp are used as dimensions 210 MaybeAffineExpr toAffineExpr(mlir::Value value) { 211 if (auto op = value.getDefiningOp<mlir::arith::SubIOp>()) 212 return affineBinaryOp( 213 mlir::AffineExprKind::Add, toAffineExpr(op.getLhs()), 214 affineBinaryOp(mlir::AffineExprKind::Mul, toAffineExpr(op.getRhs()), 215 toAffineExpr(-1))); 216 if (auto op = value.getDefiningOp<mlir::arith::AddIOp>()) 217 return affineBinaryOp(mlir::AffineExprKind::Add, op.getLhs(), 218 op.getRhs()); 219 if (auto op = value.getDefiningOp<mlir::arith::MulIOp>()) 220 return affineBinaryOp(mlir::AffineExprKind::Mul, op.getLhs(), 221 op.getRhs()); 222 if (auto op = value.getDefiningOp<mlir::arith::RemUIOp>()) 223 return affineBinaryOp(mlir::AffineExprKind::Mod, op.getLhs(), 224 op.getRhs()); 225 if (auto op = value.getDefiningOp<mlir::arith::ConstantOp>()) 226 if (auto intConstant = op.getValue().dyn_cast<IntegerAttr>()) 227 return toAffineExpr(intConstant.getInt()); 228 if (auto blockArg = value.dyn_cast<mlir::BlockArgument>()) { 229 affineArgs.push_back(value); 230 if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()) || 231 isa<mlir::AffineForOp>(blockArg.getOwner()->getParentOp())) 232 return {mlir::getAffineDimExpr(dimCount++, value.getContext())}; 233 return {mlir::getAffineSymbolExpr(symCount++, value.getContext())}; 234 } 235 return {}; 236 } 237 238 void fromCmpIOp(mlir::arith::CmpIOp cmpOp) { 239 auto lhsAffine = toAffineExpr(cmpOp.getLhs()); 240 auto rhsAffine = toAffineExpr(cmpOp.getRhs()); 241 if (!lhsAffine || !rhsAffine) 242 return; 243 auto constraintPair = 244 constraint(cmpOp.getPredicate(), *rhsAffine - *lhsAffine); 245 if (!constraintPair) 246 return; 247 integerSet = mlir::IntegerSet::get( 248 dimCount, symCount, {constraintPair->first}, {constraintPair->second}); 249 } 250 251 std::optional<std::pair<AffineExpr, bool>> 252 constraint(mlir::arith::CmpIPredicate predicate, mlir::AffineExpr basic) { 253 switch (predicate) { 254 case mlir::arith::CmpIPredicate::slt: 255 return {std::make_pair(basic - 1, false)}; 256 case mlir::arith::CmpIPredicate::sle: 257 return {std::make_pair(basic, false)}; 258 case mlir::arith::CmpIPredicate::sgt: 259 return {std::make_pair(1 - basic, false)}; 260 case mlir::arith::CmpIPredicate::sge: 261 return {std::make_pair(0 - basic, false)}; 262 case mlir::arith::CmpIPredicate::eq: 263 return {std::make_pair(basic, true)}; 264 default: 265 return {}; 266 } 267 } 268 269 llvm::SmallVector<mlir::Value> affineArgs; 270 std::optional<mlir::IntegerSet> integerSet; 271 mlir::Value firCondition; 272 unsigned symCount{0u}; 273 unsigned dimCount{0u}; 274 }; 275 } // namespace 276 277 namespace { 278 /// Analysis for affine promotion of fir.if 279 struct AffineIfAnalysis { 280 AffineIfAnalysis() = default; 281 282 explicit AffineIfAnalysis(fir::IfOp op, AffineFunctionAnalysis &afa) 283 : legality(analyzeIf(op, afa)) {} 284 285 bool canPromoteToAffine() { return legality; } 286 287 private: 288 bool analyzeIf(fir::IfOp op, AffineFunctionAnalysis &afa) { 289 if (op.getNumResults() == 0) 290 return true; 291 LLVM_DEBUG(llvm::dbgs() 292 << "AffineIfAnalysis: not promoting as op has results\n";); 293 return false; 294 } 295 296 bool legality{}; 297 }; 298 } // namespace 299 300 AffineIfAnalysis 301 AffineFunctionAnalysis::getChildIfAnalysis(fir::IfOp op) const { 302 auto it = ifAnalysisMap.find_as(op); 303 if (it == ifAnalysisMap.end()) { 304 LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n"; 305 op.dump();); 306 op.emitError("error in fetching if analysis in AffineFunctionAnalysis\n"); 307 return {}; 308 } 309 return it->getSecond(); 310 } 311 312 /// AffineMap rewriting fir.array_coor operation to affine apply, 313 /// %dim = fir.gendim %lowerBound, %upperBound, %stride 314 /// %a = fir.array_coor %arr(%dim) %i 315 /// returning affineMap = affine_map<(i)[lb, ub, st] -> (i*st - lb)> 316 static mlir::AffineMap createArrayIndexAffineMap(unsigned dimensions, 317 MLIRContext *context) { 318 auto index = mlir::getAffineConstantExpr(0, context); 319 auto accuExtent = mlir::getAffineConstantExpr(1, context); 320 for (unsigned i = 0; i < dimensions; ++i) { 321 mlir::AffineExpr idx = mlir::getAffineDimExpr(i, context), 322 lowerBound = mlir::getAffineSymbolExpr(i * 3, context), 323 currentExtent = 324 mlir::getAffineSymbolExpr(i * 3 + 1, context), 325 stride = mlir::getAffineSymbolExpr(i * 3 + 2, context), 326 currentPart = (idx * stride - lowerBound) * accuExtent; 327 index = currentPart + index; 328 accuExtent = accuExtent * currentExtent; 329 } 330 return mlir::AffineMap::get(dimensions, dimensions * 3, index); 331 } 332 333 static std::optional<int64_t> constantIntegerLike(const mlir::Value value) { 334 if (auto definition = value.getDefiningOp<mlir::arith::ConstantOp>()) 335 if (auto stepAttr = definition.getValue().dyn_cast<IntegerAttr>()) 336 return stepAttr.getInt(); 337 return {}; 338 } 339 340 static mlir::Type coordinateArrayElement(fir::ArrayCoorOp op) { 341 if (auto refType = 342 op.getMemref().getType().dyn_cast_or_null<ReferenceType>()) { 343 if (auto seqType = refType.getEleTy().dyn_cast_or_null<SequenceType>()) { 344 return seqType.getEleTy(); 345 } 346 } 347 op.emitError( 348 "AffineLoopConversion: array type in coordinate operation not valid\n"); 349 return mlir::Type(); 350 } 351 352 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeOp shape, 353 SmallVectorImpl<mlir::Value> &indexArgs, 354 mlir::PatternRewriter &rewriter) { 355 auto one = rewriter.create<mlir::arith::ConstantOp>( 356 acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1)); 357 auto extents = shape.getExtents(); 358 for (auto i = extents.begin(); i < extents.end(); i++) { 359 indexArgs.push_back(one); 360 indexArgs.push_back(*i); 361 indexArgs.push_back(one); 362 } 363 } 364 365 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeShiftOp shape, 366 SmallVectorImpl<mlir::Value> &indexArgs, 367 mlir::PatternRewriter &rewriter) { 368 auto one = rewriter.create<mlir::arith::ConstantOp>( 369 acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1)); 370 auto extents = shape.getPairs(); 371 for (auto i = extents.begin(); i < extents.end();) { 372 indexArgs.push_back(*i++); 373 indexArgs.push_back(*i++); 374 indexArgs.push_back(one); 375 } 376 } 377 378 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::SliceOp slice, 379 SmallVectorImpl<mlir::Value> &indexArgs, 380 mlir::PatternRewriter &rewriter) { 381 auto extents = slice.getTriples(); 382 for (auto i = extents.begin(); i < extents.end();) { 383 indexArgs.push_back(*i++); 384 indexArgs.push_back(*i++); 385 indexArgs.push_back(*i++); 386 } 387 } 388 389 static void populateIndexArgs(fir::ArrayCoorOp acoOp, 390 SmallVectorImpl<mlir::Value> &indexArgs, 391 mlir::PatternRewriter &rewriter) { 392 if (auto shape = acoOp.getShape().getDefiningOp<ShapeOp>()) 393 return populateIndexArgs(acoOp, shape, indexArgs, rewriter); 394 if (auto shapeShift = acoOp.getShape().getDefiningOp<ShapeShiftOp>()) 395 return populateIndexArgs(acoOp, shapeShift, indexArgs, rewriter); 396 if (auto slice = acoOp.getShape().getDefiningOp<SliceOp>()) 397 return populateIndexArgs(acoOp, slice, indexArgs, rewriter); 398 } 399 400 /// Returns affine.apply and fir.convert from array_coor and gendims 401 static std::pair<mlir::AffineApplyOp, fir::ConvertOp> 402 createAffineOps(mlir::Value arrayRef, mlir::PatternRewriter &rewriter) { 403 auto acoOp = arrayRef.getDefiningOp<ArrayCoorOp>(); 404 auto affineMap = 405 createArrayIndexAffineMap(acoOp.getIndices().size(), acoOp.getContext()); 406 SmallVector<mlir::Value> indexArgs; 407 indexArgs.append(acoOp.getIndices().begin(), acoOp.getIndices().end()); 408 409 populateIndexArgs(acoOp, indexArgs, rewriter); 410 411 auto affineApply = rewriter.create<mlir::AffineApplyOp>(acoOp.getLoc(), 412 affineMap, indexArgs); 413 auto arrayElementType = coordinateArrayElement(acoOp); 414 auto newType = 415 mlir::MemRefType::get({mlir::ShapedType::kDynamic}, arrayElementType); 416 auto arrayConvert = rewriter.create<fir::ConvertOp>(acoOp.getLoc(), newType, 417 acoOp.getMemref()); 418 return std::make_pair(affineApply, arrayConvert); 419 } 420 421 static void rewriteLoad(fir::LoadOp loadOp, mlir::PatternRewriter &rewriter) { 422 rewriter.setInsertionPoint(loadOp); 423 auto affineOps = createAffineOps(loadOp.getMemref(), rewriter); 424 rewriter.replaceOpWithNewOp<mlir::AffineLoadOp>( 425 loadOp, affineOps.second.getResult(), affineOps.first.getResult()); 426 } 427 428 static void rewriteStore(fir::StoreOp storeOp, 429 mlir::PatternRewriter &rewriter) { 430 rewriter.setInsertionPoint(storeOp); 431 auto affineOps = createAffineOps(storeOp.getMemref(), rewriter); 432 rewriter.replaceOpWithNewOp<mlir::AffineStoreOp>(storeOp, storeOp.getValue(), 433 affineOps.second.getResult(), 434 affineOps.first.getResult()); 435 } 436 437 static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) { 438 for (auto &bodyOp : block->getOperations()) { 439 if (isa<fir::LoadOp>(bodyOp)) 440 rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter); 441 if (isa<fir::StoreOp>(bodyOp)) 442 rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter); 443 } 444 } 445 446 namespace { 447 /// Convert `fir.do_loop` to `affine.for`, creates fir.convert for arrays to 448 /// memref, rewrites array_coor to affine.apply with affine_map. Rewrites fir 449 /// loads and stores to affine. 450 class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> { 451 public: 452 using OpRewritePattern::OpRewritePattern; 453 AffineLoopConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa) 454 : OpRewritePattern(context), functionAnalysis(afa) {} 455 456 mlir::LogicalResult 457 matchAndRewrite(fir::DoLoopOp loop, 458 mlir::PatternRewriter &rewriter) const override { 459 LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n"; 460 loop.dump();); 461 LLVM_ATTRIBUTE_UNUSED auto loopAnalysis = 462 functionAnalysis.getChildLoopAnalysis(loop); 463 auto &loopOps = loop.getBody()->getOperations(); 464 auto loopAndIndex = createAffineFor(loop, rewriter); 465 auto affineFor = loopAndIndex.first; 466 auto inductionVar = loopAndIndex.second; 467 468 rewriter.startRootUpdate(affineFor.getOperation()); 469 affineFor.getBody()->getOperations().splice( 470 std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(), 471 std::prev(loopOps.end())); 472 rewriter.finalizeRootUpdate(affineFor.getOperation()); 473 474 rewriter.startRootUpdate(loop.getOperation()); 475 loop.getInductionVar().replaceAllUsesWith(inductionVar); 476 rewriter.finalizeRootUpdate(loop.getOperation()); 477 478 rewriteMemoryOps(affineFor.getBody(), rewriter); 479 480 LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n"; 481 affineFor.dump();); 482 rewriter.replaceOp(loop, affineFor.getOperation()->getResults()); 483 return success(); 484 } 485 486 private: 487 std::pair<mlir::AffineForOp, mlir::Value> 488 createAffineFor(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const { 489 if (auto constantStep = constantIntegerLike(op.getStep())) 490 if (*constantStep > 0) 491 return positiveConstantStep(op, *constantStep, rewriter); 492 return genericBounds(op, rewriter); 493 } 494 495 // when step for the loop is positive compile time constant 496 std::pair<mlir::AffineForOp, mlir::Value> 497 positiveConstantStep(fir::DoLoopOp op, int64_t step, 498 mlir::PatternRewriter &rewriter) const { 499 auto affineFor = rewriter.create<mlir::AffineForOp>( 500 op.getLoc(), ValueRange(op.getLowerBound()), 501 mlir::AffineMap::get(0, 1, 502 mlir::getAffineSymbolExpr(0, op.getContext())), 503 ValueRange(op.getUpperBound()), 504 mlir::AffineMap::get(0, 1, 505 1 + mlir::getAffineSymbolExpr(0, op.getContext())), 506 step); 507 return std::make_pair(affineFor, affineFor.getInductionVar()); 508 } 509 510 std::pair<mlir::AffineForOp, mlir::Value> 511 genericBounds(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const { 512 auto lowerBound = mlir::getAffineSymbolExpr(0, op.getContext()); 513 auto upperBound = mlir::getAffineSymbolExpr(1, op.getContext()); 514 auto step = mlir::getAffineSymbolExpr(2, op.getContext()); 515 mlir::AffineMap upperBoundMap = mlir::AffineMap::get( 516 0, 3, (upperBound - lowerBound + step).floorDiv(step)); 517 auto genericUpperBound = rewriter.create<mlir::AffineApplyOp>( 518 op.getLoc(), upperBoundMap, 519 ValueRange({op.getLowerBound(), op.getUpperBound(), op.getStep()})); 520 auto actualIndexMap = mlir::AffineMap::get( 521 1, 2, 522 (lowerBound + mlir::getAffineDimExpr(0, op.getContext())) * 523 mlir::getAffineSymbolExpr(1, op.getContext())); 524 525 auto affineFor = rewriter.create<mlir::AffineForOp>( 526 op.getLoc(), ValueRange(), 527 AffineMap::getConstantMap(0, op.getContext()), 528 genericUpperBound.getResult(), 529 mlir::AffineMap::get(0, 1, 530 1 + mlir::getAffineSymbolExpr(0, op.getContext())), 531 1); 532 rewriter.setInsertionPointToStart(affineFor.getBody()); 533 auto actualIndex = rewriter.create<mlir::AffineApplyOp>( 534 op.getLoc(), actualIndexMap, 535 ValueRange( 536 {affineFor.getInductionVar(), op.getLowerBound(), op.getStep()})); 537 return std::make_pair(affineFor, actualIndex.getResult()); 538 } 539 540 AffineFunctionAnalysis &functionAnalysis; 541 }; 542 543 /// Convert `fir.if` to `affine.if`. 544 class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> { 545 public: 546 using OpRewritePattern::OpRewritePattern; 547 AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa) 548 : OpRewritePattern(context) {} 549 mlir::LogicalResult 550 matchAndRewrite(fir::IfOp op, 551 mlir::PatternRewriter &rewriter) const override { 552 LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n"; 553 op.dump();); 554 auto &ifOps = op.getThenRegion().front().getOperations(); 555 auto affineCondition = AffineIfCondition(op.getCondition()); 556 if (!affineCondition.hasIntegerSet()) { 557 LLVM_DEBUG( 558 llvm::dbgs() 559 << "AffineIfConversion: couldn't calculate affine condition\n";); 560 return failure(); 561 } 562 auto affineIf = rewriter.create<mlir::AffineIfOp>( 563 op.getLoc(), affineCondition.getIntegerSet(), 564 affineCondition.getAffineArgs(), !op.getElseRegion().empty()); 565 rewriter.startRootUpdate(affineIf); 566 affineIf.getThenBlock()->getOperations().splice( 567 std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(), 568 std::prev(ifOps.end())); 569 if (!op.getElseRegion().empty()) { 570 auto &otherOps = op.getElseRegion().front().getOperations(); 571 affineIf.getElseBlock()->getOperations().splice( 572 std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(), 573 std::prev(otherOps.end())); 574 } 575 rewriter.finalizeRootUpdate(affineIf); 576 rewriteMemoryOps(affineIf.getBody(), rewriter); 577 578 LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n"; 579 affineIf.dump();); 580 rewriter.replaceOp(op, affineIf.getOperation()->getResults()); 581 return success(); 582 } 583 }; 584 585 /// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases 586 /// where such a promotion is possible. 587 class AffineDialectPromotion 588 : public fir::impl::AffineDialectPromotionBase<AffineDialectPromotion> { 589 public: 590 void runOnOperation() override { 591 592 auto *context = &getContext(); 593 auto function = getOperation(); 594 markAllAnalysesPreserved(); 595 auto functionAnalysis = AffineFunctionAnalysis(function); 596 mlir::RewritePatternSet patterns(context); 597 patterns.insert<AffineIfConversion>(context, functionAnalysis); 598 patterns.insert<AffineLoopConversion>(context, functionAnalysis); 599 mlir::ConversionTarget target = *context; 600 target.addLegalDialect<mlir::AffineDialect, FIROpsDialect, 601 mlir::scf::SCFDialect, mlir::arith::ArithDialect, 602 mlir::func::FuncDialect>(); 603 target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) { 604 return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine()); 605 }); 606 target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis]( 607 fir::DoLoopOp op) { 608 return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine()); 609 }); 610 611 LLVM_DEBUG(llvm::dbgs() 612 << "AffineDialectPromotion: running promotion on: \n"; 613 function.print(llvm::dbgs());); 614 // apply the patterns 615 if (mlir::failed(mlir::applyPartialConversion(function, target, 616 std::move(patterns)))) { 617 mlir::emitError(mlir::UnknownLoc::get(context), 618 "error in converting to affine dialect\n"); 619 signalPassFailure(); 620 } 621 } 622 }; 623 } // namespace 624 625 /// Convert FIR loop constructs to the Affine dialect 626 std::unique_ptr<mlir::Pass> fir::createPromoteToAffinePass() { 627 return std::make_unique<AffineDialectPromotion>(); 628 } 629