1 //===-- AffinePromotion.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation is a prototype that promote FIR loops operations 10 // to affine dialect operations. 11 // It is not part of the production pipeline and would need more work in order 12 // to be used in production. 13 // More information can be found in this presentation: 14 // https://slides.com/rajanwalia/deck 15 // 16 //===----------------------------------------------------------------------===// 17 18 #include "flang/Optimizer/Dialect/FIRDialect.h" 19 #include "flang/Optimizer/Dialect/FIROps.h" 20 #include "flang/Optimizer/Dialect/FIRType.h" 21 #include "flang/Optimizer/Transforms/Passes.h" 22 #include "mlir/Dialect/Affine/IR/AffineOps.h" 23 #include "mlir/Dialect/Func/IR/FuncOps.h" 24 #include "mlir/Dialect/SCF/IR/SCF.h" 25 #include "mlir/IR/BuiltinAttributes.h" 26 #include "mlir/IR/IntegerSet.h" 27 #include "mlir/IR/Visitors.h" 28 #include "mlir/Transforms/DialectConversion.h" 29 #include "llvm/ADT/DenseMap.h" 30 #include "llvm/Support/Debug.h" 31 #include <optional> 32 33 namespace fir { 34 #define GEN_PASS_DEF_AFFINEDIALECTPROMOTION 35 #include "flang/Optimizer/Transforms/Passes.h.inc" 36 } // namespace fir 37 38 #define DEBUG_TYPE "flang-affine-promotion" 39 40 using namespace fir; 41 using namespace mlir; 42 43 namespace { 44 struct AffineLoopAnalysis; 45 struct AffineIfAnalysis; 46 47 /// Stores analysis objects for all loops and if operations inside a function 48 /// these analysis are used twice, first for marking operations for rewrite and 49 /// second when doing rewrite. 50 struct AffineFunctionAnalysis { 51 explicit AffineFunctionAnalysis(mlir::func::FuncOp funcOp) { 52 for (fir::DoLoopOp op : funcOp.getOps<fir::DoLoopOp>()) 53 loopAnalysisMap.try_emplace(op, op, *this); 54 } 55 56 AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const; 57 58 AffineIfAnalysis getChildIfAnalysis(fir::IfOp op) const; 59 60 llvm::DenseMap<mlir::Operation *, AffineLoopAnalysis> loopAnalysisMap; 61 llvm::DenseMap<mlir::Operation *, AffineIfAnalysis> ifAnalysisMap; 62 }; 63 } // namespace 64 65 static bool analyzeCoordinate(mlir::Value coordinate, mlir::Operation *op) { 66 if (auto blockArg = coordinate.dyn_cast<mlir::BlockArgument>()) { 67 if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp())) 68 return true; 69 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a " 70 "loop induction variable (owner not loopOp)\n"; 71 op->dump()); 72 return false; 73 } 74 LLVM_DEBUG( 75 llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a loop " 76 "induction variable (not a block argument)\n"; 77 op->dump(); coordinate.getDefiningOp()->dump()); 78 return false; 79 } 80 81 namespace { 82 struct AffineLoopAnalysis { 83 AffineLoopAnalysis() = default; 84 85 explicit AffineLoopAnalysis(fir::DoLoopOp op, AffineFunctionAnalysis &afa) 86 : legality(analyzeLoop(op, afa)) {} 87 88 bool canPromoteToAffine() { return legality; } 89 90 private: 91 bool analyzeBody(fir::DoLoopOp loopOperation, 92 AffineFunctionAnalysis &functionAnalysis) { 93 for (auto loopOp : loopOperation.getOps<fir::DoLoopOp>()) { 94 auto analysis = functionAnalysis.loopAnalysisMap 95 .try_emplace(loopOp, loopOp, functionAnalysis) 96 .first->getSecond(); 97 if (!analysis.canPromoteToAffine()) 98 return false; 99 } 100 for (auto ifOp : loopOperation.getOps<fir::IfOp>()) 101 functionAnalysis.ifAnalysisMap.try_emplace(ifOp, ifOp, functionAnalysis); 102 return true; 103 } 104 105 bool analyzeLoop(fir::DoLoopOp loopOperation, 106 AffineFunctionAnalysis &functionAnalysis) { 107 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n"; loopOperation.dump();); 108 return analyzeMemoryAccess(loopOperation) && 109 analyzeBody(loopOperation, functionAnalysis); 110 } 111 112 bool analyzeReference(mlir::Value memref, mlir::Operation *op) { 113 if (auto acoOp = memref.getDefiningOp<ArrayCoorOp>()) { 114 if (acoOp.getMemref().getType().isa<fir::BoxType>()) { 115 // TODO: Look if and how fir.box can be promoted to affine. 116 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: cannot promote loop, " 117 "array memory operation uses fir.box\n"; 118 op->dump(); acoOp.dump();); 119 return false; 120 } 121 bool canPromote = true; 122 for (auto coordinate : acoOp.getIndices()) 123 canPromote = canPromote && analyzeCoordinate(coordinate, op); 124 return canPromote; 125 } 126 if (auto coOp = memref.getDefiningOp<CoordinateOp>()) { 127 LLVM_DEBUG(llvm::dbgs() 128 << "AffineLoopAnalysis: cannot promote loop, " 129 "array memory operation uses non ArrayCoorOp\n"; 130 op->dump(); coOp.dump();); 131 132 return false; 133 } 134 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: unknown type of memory " 135 "reference for array load\n"; 136 op->dump();); 137 return false; 138 } 139 140 bool analyzeMemoryAccess(fir::DoLoopOp loopOperation) { 141 for (auto loadOp : loopOperation.getOps<fir::LoadOp>()) 142 if (!analyzeReference(loadOp.getMemref(), loadOp)) 143 return false; 144 for (auto storeOp : loopOperation.getOps<fir::StoreOp>()) 145 if (!analyzeReference(storeOp.getMemref(), storeOp)) 146 return false; 147 return true; 148 } 149 150 bool legality{}; 151 }; 152 } // namespace 153 154 AffineLoopAnalysis 155 AffineFunctionAnalysis::getChildLoopAnalysis(fir::DoLoopOp op) const { 156 auto it = loopAnalysisMap.find_as(op); 157 if (it == loopAnalysisMap.end()) { 158 LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n"; 159 op.dump();); 160 op.emitError("error in fetching loop analysis in AffineFunctionAnalysis\n"); 161 return {}; 162 } 163 return it->getSecond(); 164 } 165 166 namespace { 167 /// Calculates arguments for creating an IntegerSet. symCount, dimCount are the 168 /// final number of symbols and dimensions of the affine map. Integer set if 169 /// possible is in Optional IntegerSet. 170 struct AffineIfCondition { 171 using MaybeAffineExpr = std::optional<mlir::AffineExpr>; 172 173 explicit AffineIfCondition(mlir::Value fc) : firCondition(fc) { 174 if (auto condDef = firCondition.getDefiningOp<mlir::arith::CmpIOp>()) 175 fromCmpIOp(condDef); 176 } 177 178 bool hasIntegerSet() const { return integerSet.has_value(); } 179 180 mlir::IntegerSet getIntegerSet() const { 181 assert(hasIntegerSet() && "integer set is missing"); 182 return *integerSet; 183 } 184 185 mlir::ValueRange getAffineArgs() const { return affineArgs; } 186 187 private: 188 MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, mlir::Value lhs, 189 mlir::Value rhs) { 190 return affineBinaryOp(kind, toAffineExpr(lhs), toAffineExpr(rhs)); 191 } 192 193 MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, MaybeAffineExpr lhs, 194 MaybeAffineExpr rhs) { 195 if (lhs && rhs) 196 return mlir::getAffineBinaryOpExpr(kind, *lhs, *rhs); 197 return {}; 198 } 199 200 MaybeAffineExpr toAffineExpr(MaybeAffineExpr e) { return e; } 201 202 MaybeAffineExpr toAffineExpr(int64_t value) { 203 return {mlir::getAffineConstantExpr(value, firCondition.getContext())}; 204 } 205 206 /// Returns an AffineExpr if it is a result of operations that can be done 207 /// in an affine expression, this includes -, +, *, rem, constant. 208 /// block arguments of a loopOp or forOp are used as dimensions 209 MaybeAffineExpr toAffineExpr(mlir::Value value) { 210 if (auto op = value.getDefiningOp<mlir::arith::SubIOp>()) 211 return affineBinaryOp( 212 mlir::AffineExprKind::Add, toAffineExpr(op.getLhs()), 213 affineBinaryOp(mlir::AffineExprKind::Mul, toAffineExpr(op.getRhs()), 214 toAffineExpr(-1))); 215 if (auto op = value.getDefiningOp<mlir::arith::AddIOp>()) 216 return affineBinaryOp(mlir::AffineExprKind::Add, op.getLhs(), 217 op.getRhs()); 218 if (auto op = value.getDefiningOp<mlir::arith::MulIOp>()) 219 return affineBinaryOp(mlir::AffineExprKind::Mul, op.getLhs(), 220 op.getRhs()); 221 if (auto op = value.getDefiningOp<mlir::arith::RemUIOp>()) 222 return affineBinaryOp(mlir::AffineExprKind::Mod, op.getLhs(), 223 op.getRhs()); 224 if (auto op = value.getDefiningOp<mlir::arith::ConstantOp>()) 225 if (auto intConstant = op.getValue().dyn_cast<IntegerAttr>()) 226 return toAffineExpr(intConstant.getInt()); 227 if (auto blockArg = value.dyn_cast<mlir::BlockArgument>()) { 228 affineArgs.push_back(value); 229 if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()) || 230 isa<mlir::affine::AffineForOp>(blockArg.getOwner()->getParentOp())) 231 return {mlir::getAffineDimExpr(dimCount++, value.getContext())}; 232 return {mlir::getAffineSymbolExpr(symCount++, value.getContext())}; 233 } 234 return {}; 235 } 236 237 void fromCmpIOp(mlir::arith::CmpIOp cmpOp) { 238 auto lhsAffine = toAffineExpr(cmpOp.getLhs()); 239 auto rhsAffine = toAffineExpr(cmpOp.getRhs()); 240 if (!lhsAffine || !rhsAffine) 241 return; 242 auto constraintPair = 243 constraint(cmpOp.getPredicate(), *rhsAffine - *lhsAffine); 244 if (!constraintPair) 245 return; 246 integerSet = mlir::IntegerSet::get( 247 dimCount, symCount, {constraintPair->first}, {constraintPair->second}); 248 } 249 250 std::optional<std::pair<AffineExpr, bool>> 251 constraint(mlir::arith::CmpIPredicate predicate, mlir::AffineExpr basic) { 252 switch (predicate) { 253 case mlir::arith::CmpIPredicate::slt: 254 return {std::make_pair(basic - 1, false)}; 255 case mlir::arith::CmpIPredicate::sle: 256 return {std::make_pair(basic, false)}; 257 case mlir::arith::CmpIPredicate::sgt: 258 return {std::make_pair(1 - basic, false)}; 259 case mlir::arith::CmpIPredicate::sge: 260 return {std::make_pair(0 - basic, false)}; 261 case mlir::arith::CmpIPredicate::eq: 262 return {std::make_pair(basic, true)}; 263 default: 264 return {}; 265 } 266 } 267 268 llvm::SmallVector<mlir::Value> affineArgs; 269 std::optional<mlir::IntegerSet> integerSet; 270 mlir::Value firCondition; 271 unsigned symCount{0u}; 272 unsigned dimCount{0u}; 273 }; 274 } // namespace 275 276 namespace { 277 /// Analysis for affine promotion of fir.if 278 struct AffineIfAnalysis { 279 AffineIfAnalysis() = default; 280 281 explicit AffineIfAnalysis(fir::IfOp op, AffineFunctionAnalysis &afa) 282 : legality(analyzeIf(op, afa)) {} 283 284 bool canPromoteToAffine() { return legality; } 285 286 private: 287 bool analyzeIf(fir::IfOp op, AffineFunctionAnalysis &afa) { 288 if (op.getNumResults() == 0) 289 return true; 290 LLVM_DEBUG(llvm::dbgs() 291 << "AffineIfAnalysis: not promoting as op has results\n";); 292 return false; 293 } 294 295 bool legality{}; 296 }; 297 } // namespace 298 299 AffineIfAnalysis 300 AffineFunctionAnalysis::getChildIfAnalysis(fir::IfOp op) const { 301 auto it = ifAnalysisMap.find_as(op); 302 if (it == ifAnalysisMap.end()) { 303 LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n"; 304 op.dump();); 305 op.emitError("error in fetching if analysis in AffineFunctionAnalysis\n"); 306 return {}; 307 } 308 return it->getSecond(); 309 } 310 311 /// AffineMap rewriting fir.array_coor operation to affine apply, 312 /// %dim = fir.gendim %lowerBound, %upperBound, %stride 313 /// %a = fir.array_coor %arr(%dim) %i 314 /// returning affineMap = affine_map<(i)[lb, ub, st] -> (i*st - lb)> 315 static mlir::AffineMap createArrayIndexAffineMap(unsigned dimensions, 316 MLIRContext *context) { 317 auto index = mlir::getAffineConstantExpr(0, context); 318 auto accuExtent = mlir::getAffineConstantExpr(1, context); 319 for (unsigned i = 0; i < dimensions; ++i) { 320 mlir::AffineExpr idx = mlir::getAffineDimExpr(i, context), 321 lowerBound = mlir::getAffineSymbolExpr(i * 3, context), 322 currentExtent = 323 mlir::getAffineSymbolExpr(i * 3 + 1, context), 324 stride = mlir::getAffineSymbolExpr(i * 3 + 2, context), 325 currentPart = (idx * stride - lowerBound) * accuExtent; 326 index = currentPart + index; 327 accuExtent = accuExtent * currentExtent; 328 } 329 return mlir::AffineMap::get(dimensions, dimensions * 3, index); 330 } 331 332 static std::optional<int64_t> constantIntegerLike(const mlir::Value value) { 333 if (auto definition = value.getDefiningOp<mlir::arith::ConstantOp>()) 334 if (auto stepAttr = definition.getValue().dyn_cast<IntegerAttr>()) 335 return stepAttr.getInt(); 336 return {}; 337 } 338 339 static mlir::Type coordinateArrayElement(fir::ArrayCoorOp op) { 340 if (auto refType = 341 op.getMemref().getType().dyn_cast_or_null<ReferenceType>()) { 342 if (auto seqType = refType.getEleTy().dyn_cast_or_null<SequenceType>()) { 343 return seqType.getEleTy(); 344 } 345 } 346 op.emitError( 347 "AffineLoopConversion: array type in coordinate operation not valid\n"); 348 return mlir::Type(); 349 } 350 351 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeOp shape, 352 SmallVectorImpl<mlir::Value> &indexArgs, 353 mlir::PatternRewriter &rewriter) { 354 auto one = rewriter.create<mlir::arith::ConstantOp>( 355 acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1)); 356 auto extents = shape.getExtents(); 357 for (auto i = extents.begin(); i < extents.end(); i++) { 358 indexArgs.push_back(one); 359 indexArgs.push_back(*i); 360 indexArgs.push_back(one); 361 } 362 } 363 364 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeShiftOp shape, 365 SmallVectorImpl<mlir::Value> &indexArgs, 366 mlir::PatternRewriter &rewriter) { 367 auto one = rewriter.create<mlir::arith::ConstantOp>( 368 acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1)); 369 auto extents = shape.getPairs(); 370 for (auto i = extents.begin(); i < extents.end();) { 371 indexArgs.push_back(*i++); 372 indexArgs.push_back(*i++); 373 indexArgs.push_back(one); 374 } 375 } 376 377 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::SliceOp slice, 378 SmallVectorImpl<mlir::Value> &indexArgs, 379 mlir::PatternRewriter &rewriter) { 380 auto extents = slice.getTriples(); 381 for (auto i = extents.begin(); i < extents.end();) { 382 indexArgs.push_back(*i++); 383 indexArgs.push_back(*i++); 384 indexArgs.push_back(*i++); 385 } 386 } 387 388 static void populateIndexArgs(fir::ArrayCoorOp acoOp, 389 SmallVectorImpl<mlir::Value> &indexArgs, 390 mlir::PatternRewriter &rewriter) { 391 if (auto shape = acoOp.getShape().getDefiningOp<ShapeOp>()) 392 return populateIndexArgs(acoOp, shape, indexArgs, rewriter); 393 if (auto shapeShift = acoOp.getShape().getDefiningOp<ShapeShiftOp>()) 394 return populateIndexArgs(acoOp, shapeShift, indexArgs, rewriter); 395 if (auto slice = acoOp.getShape().getDefiningOp<SliceOp>()) 396 return populateIndexArgs(acoOp, slice, indexArgs, rewriter); 397 } 398 399 /// Returns affine.apply and fir.convert from array_coor and gendims 400 static std::pair<affine::AffineApplyOp, fir::ConvertOp> 401 createAffineOps(mlir::Value arrayRef, mlir::PatternRewriter &rewriter) { 402 auto acoOp = arrayRef.getDefiningOp<ArrayCoorOp>(); 403 auto affineMap = 404 createArrayIndexAffineMap(acoOp.getIndices().size(), acoOp.getContext()); 405 SmallVector<mlir::Value> indexArgs; 406 indexArgs.append(acoOp.getIndices().begin(), acoOp.getIndices().end()); 407 408 populateIndexArgs(acoOp, indexArgs, rewriter); 409 410 auto affineApply = rewriter.create<affine::AffineApplyOp>( 411 acoOp.getLoc(), affineMap, indexArgs); 412 auto arrayElementType = coordinateArrayElement(acoOp); 413 auto newType = 414 mlir::MemRefType::get({mlir::ShapedType::kDynamic}, arrayElementType); 415 auto arrayConvert = rewriter.create<fir::ConvertOp>(acoOp.getLoc(), newType, 416 acoOp.getMemref()); 417 return std::make_pair(affineApply, arrayConvert); 418 } 419 420 static void rewriteLoad(fir::LoadOp loadOp, mlir::PatternRewriter &rewriter) { 421 rewriter.setInsertionPoint(loadOp); 422 auto affineOps = createAffineOps(loadOp.getMemref(), rewriter); 423 rewriter.replaceOpWithNewOp<affine::AffineLoadOp>( 424 loadOp, affineOps.second.getResult(), affineOps.first.getResult()); 425 } 426 427 static void rewriteStore(fir::StoreOp storeOp, 428 mlir::PatternRewriter &rewriter) { 429 rewriter.setInsertionPoint(storeOp); 430 auto affineOps = createAffineOps(storeOp.getMemref(), rewriter); 431 rewriter.replaceOpWithNewOp<affine::AffineStoreOp>( 432 storeOp, storeOp.getValue(), affineOps.second.getResult(), 433 affineOps.first.getResult()); 434 } 435 436 static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) { 437 for (auto &bodyOp : block->getOperations()) { 438 if (isa<fir::LoadOp>(bodyOp)) 439 rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter); 440 if (isa<fir::StoreOp>(bodyOp)) 441 rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter); 442 } 443 } 444 445 namespace { 446 /// Convert `fir.do_loop` to `affine.for`, creates fir.convert for arrays to 447 /// memref, rewrites array_coor to affine.apply with affine_map. Rewrites fir 448 /// loads and stores to affine. 449 class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> { 450 public: 451 using OpRewritePattern::OpRewritePattern; 452 AffineLoopConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa) 453 : OpRewritePattern(context), functionAnalysis(afa) {} 454 455 mlir::LogicalResult 456 matchAndRewrite(fir::DoLoopOp loop, 457 mlir::PatternRewriter &rewriter) const override { 458 LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n"; 459 loop.dump();); 460 LLVM_ATTRIBUTE_UNUSED auto loopAnalysis = 461 functionAnalysis.getChildLoopAnalysis(loop); 462 auto &loopOps = loop.getBody()->getOperations(); 463 auto loopAndIndex = createAffineFor(loop, rewriter); 464 auto affineFor = loopAndIndex.first; 465 auto inductionVar = loopAndIndex.second; 466 467 rewriter.startRootUpdate(affineFor.getOperation()); 468 affineFor.getBody()->getOperations().splice( 469 std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(), 470 std::prev(loopOps.end())); 471 rewriter.finalizeRootUpdate(affineFor.getOperation()); 472 473 rewriter.startRootUpdate(loop.getOperation()); 474 loop.getInductionVar().replaceAllUsesWith(inductionVar); 475 rewriter.finalizeRootUpdate(loop.getOperation()); 476 477 rewriteMemoryOps(affineFor.getBody(), rewriter); 478 479 LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n"; 480 affineFor.dump();); 481 rewriter.replaceOp(loop, affineFor.getOperation()->getResults()); 482 return success(); 483 } 484 485 private: 486 std::pair<affine::AffineForOp, mlir::Value> 487 createAffineFor(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const { 488 if (auto constantStep = constantIntegerLike(op.getStep())) 489 if (*constantStep > 0) 490 return positiveConstantStep(op, *constantStep, rewriter); 491 return genericBounds(op, rewriter); 492 } 493 494 // when step for the loop is positive compile time constant 495 std::pair<affine::AffineForOp, mlir::Value> 496 positiveConstantStep(fir::DoLoopOp op, int64_t step, 497 mlir::PatternRewriter &rewriter) const { 498 auto affineFor = rewriter.create<affine::AffineForOp>( 499 op.getLoc(), ValueRange(op.getLowerBound()), 500 mlir::AffineMap::get(0, 1, 501 mlir::getAffineSymbolExpr(0, op.getContext())), 502 ValueRange(op.getUpperBound()), 503 mlir::AffineMap::get(0, 1, 504 1 + mlir::getAffineSymbolExpr(0, op.getContext())), 505 step); 506 return std::make_pair(affineFor, affineFor.getInductionVar()); 507 } 508 509 std::pair<affine::AffineForOp, mlir::Value> 510 genericBounds(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const { 511 auto lowerBound = mlir::getAffineSymbolExpr(0, op.getContext()); 512 auto upperBound = mlir::getAffineSymbolExpr(1, op.getContext()); 513 auto step = mlir::getAffineSymbolExpr(2, op.getContext()); 514 mlir::AffineMap upperBoundMap = mlir::AffineMap::get( 515 0, 3, (upperBound - lowerBound + step).floorDiv(step)); 516 auto genericUpperBound = rewriter.create<affine::AffineApplyOp>( 517 op.getLoc(), upperBoundMap, 518 ValueRange({op.getLowerBound(), op.getUpperBound(), op.getStep()})); 519 auto actualIndexMap = mlir::AffineMap::get( 520 1, 2, 521 (lowerBound + mlir::getAffineDimExpr(0, op.getContext())) * 522 mlir::getAffineSymbolExpr(1, op.getContext())); 523 524 auto affineFor = rewriter.create<affine::AffineForOp>( 525 op.getLoc(), ValueRange(), 526 AffineMap::getConstantMap(0, op.getContext()), 527 genericUpperBound.getResult(), 528 mlir::AffineMap::get(0, 1, 529 1 + mlir::getAffineSymbolExpr(0, op.getContext())), 530 1); 531 rewriter.setInsertionPointToStart(affineFor.getBody()); 532 auto actualIndex = rewriter.create<affine::AffineApplyOp>( 533 op.getLoc(), actualIndexMap, 534 ValueRange( 535 {affineFor.getInductionVar(), op.getLowerBound(), op.getStep()})); 536 return std::make_pair(affineFor, actualIndex.getResult()); 537 } 538 539 AffineFunctionAnalysis &functionAnalysis; 540 }; 541 542 /// Convert `fir.if` to `affine.if`. 543 class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> { 544 public: 545 using OpRewritePattern::OpRewritePattern; 546 AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa) 547 : OpRewritePattern(context) {} 548 mlir::LogicalResult 549 matchAndRewrite(fir::IfOp op, 550 mlir::PatternRewriter &rewriter) const override { 551 LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n"; 552 op.dump();); 553 auto &ifOps = op.getThenRegion().front().getOperations(); 554 auto affineCondition = AffineIfCondition(op.getCondition()); 555 if (!affineCondition.hasIntegerSet()) { 556 LLVM_DEBUG( 557 llvm::dbgs() 558 << "AffineIfConversion: couldn't calculate affine condition\n";); 559 return failure(); 560 } 561 auto affineIf = rewriter.create<affine::AffineIfOp>( 562 op.getLoc(), affineCondition.getIntegerSet(), 563 affineCondition.getAffineArgs(), !op.getElseRegion().empty()); 564 rewriter.startRootUpdate(affineIf); 565 affineIf.getThenBlock()->getOperations().splice( 566 std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(), 567 std::prev(ifOps.end())); 568 if (!op.getElseRegion().empty()) { 569 auto &otherOps = op.getElseRegion().front().getOperations(); 570 affineIf.getElseBlock()->getOperations().splice( 571 std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(), 572 std::prev(otherOps.end())); 573 } 574 rewriter.finalizeRootUpdate(affineIf); 575 rewriteMemoryOps(affineIf.getBody(), rewriter); 576 577 LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n"; 578 affineIf.dump();); 579 rewriter.replaceOp(op, affineIf.getOperation()->getResults()); 580 return success(); 581 } 582 }; 583 584 /// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases 585 /// where such a promotion is possible. 586 class AffineDialectPromotion 587 : public fir::impl::AffineDialectPromotionBase<AffineDialectPromotion> { 588 public: 589 void runOnOperation() override { 590 591 auto *context = &getContext(); 592 auto function = getOperation(); 593 markAllAnalysesPreserved(); 594 auto functionAnalysis = AffineFunctionAnalysis(function); 595 mlir::RewritePatternSet patterns(context); 596 patterns.insert<AffineIfConversion>(context, functionAnalysis); 597 patterns.insert<AffineLoopConversion>(context, functionAnalysis); 598 mlir::ConversionTarget target = *context; 599 target.addLegalDialect<mlir::affine::AffineDialect, FIROpsDialect, 600 mlir::scf::SCFDialect, mlir::arith::ArithDialect, 601 mlir::func::FuncDialect>(); 602 target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) { 603 return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine()); 604 }); 605 target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis]( 606 fir::DoLoopOp op) { 607 return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine()); 608 }); 609 610 LLVM_DEBUG(llvm::dbgs() 611 << "AffineDialectPromotion: running promotion on: \n"; 612 function.print(llvm::dbgs());); 613 // apply the patterns 614 if (mlir::failed(mlir::applyPartialConversion(function, target, 615 std::move(patterns)))) { 616 mlir::emitError(mlir::UnknownLoc::get(context), 617 "error in converting to affine dialect\n"); 618 signalPassFailure(); 619 } 620 } 621 }; 622 } // namespace 623 624 /// Convert FIR loop constructs to the Affine dialect 625 std::unique_ptr<mlir::Pass> fir::createPromoteToAffinePass() { 626 return std::make_unique<AffineDialectPromotion>(); 627 } 628