xref: /llvm-project/flang/lib/Optimizer/Transforms/AffinePromotion.cpp (revision c82fb16f580c80b961cb1971df66ce70eaf7d87a)
1 //===-- AffinePromotion.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This transformation is a prototype that promotes FIR loop operations
// to affine dialect operations.
11 // It is not part of the production pipeline and would need more work in order
12 // to be used in production.
13 // More information can be found in this presentation:
14 // https://slides.com/rajanwalia/deck
15 //
16 //===----------------------------------------------------------------------===//
17 
#include "PassDetail.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
33 
34 #define DEBUG_TYPE "flang-affine-promotion"
35 
36 using namespace fir;
37 using namespace mlir;
38 
39 namespace {
40 struct AffineLoopAnalysis;
41 struct AffineIfAnalysis;
42 
43 /// Stores analysis objects for all loops and if operations inside a function
44 /// these analysis are used twice, first for marking operations for rewrite and
45 /// second when doing rewrite.
46 struct AffineFunctionAnalysis {
47   explicit AffineFunctionAnalysis(mlir::func::FuncOp funcOp) {
48     for (fir::DoLoopOp op : funcOp.getOps<fir::DoLoopOp>())
49       loopAnalysisMap.try_emplace(op, op, *this);
50   }
51 
52   AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const;
53 
54   AffineIfAnalysis getChildIfAnalysis(fir::IfOp op) const;
55 
56   llvm::DenseMap<mlir::Operation *, AffineLoopAnalysis> loopAnalysisMap;
57   llvm::DenseMap<mlir::Operation *, AffineIfAnalysis> ifAnalysisMap;
58 };
59 } // namespace
60 
61 static bool analyzeCoordinate(mlir::Value coordinate, mlir::Operation *op) {
62   if (auto blockArg = coordinate.dyn_cast<mlir::BlockArgument>()) {
63     if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()))
64       return true;
65     LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a "
66                                "loop induction variable (owner not loopOp)\n";
67                op->dump());
68     return false;
69   }
70   LLVM_DEBUG(
71       llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a loop "
72                       "induction variable (not a block argument)\n";
73       op->dump(); coordinate.getDefiningOp()->dump());
74   return false;
75 }
76 
77 namespace {
78 struct AffineLoopAnalysis {
79   AffineLoopAnalysis() = default;
80 
81   explicit AffineLoopAnalysis(fir::DoLoopOp op, AffineFunctionAnalysis &afa)
82       : legality(analyzeLoop(op, afa)) {}
83 
84   bool canPromoteToAffine() { return legality; }
85 
86 private:
87   bool analyzeBody(fir::DoLoopOp loopOperation,
88                    AffineFunctionAnalysis &functionAnalysis) {
89     for (auto loopOp : loopOperation.getOps<fir::DoLoopOp>()) {
90       auto analysis = functionAnalysis.loopAnalysisMap
91                           .try_emplace(loopOp, loopOp, functionAnalysis)
92                           .first->getSecond();
93       if (!analysis.canPromoteToAffine())
94         return false;
95     }
96     for (auto ifOp : loopOperation.getOps<fir::IfOp>())
97       functionAnalysis.ifAnalysisMap.try_emplace(ifOp, ifOp, functionAnalysis);
98     return true;
99   }
100 
101   bool analyzeLoop(fir::DoLoopOp loopOperation,
102                    AffineFunctionAnalysis &functionAnalysis) {
103     LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n"; loopOperation.dump(););
104     return analyzeMemoryAccess(loopOperation) &&
105            analyzeBody(loopOperation, functionAnalysis);
106   }
107 
108   bool analyzeReference(mlir::Value memref, mlir::Operation *op) {
109     if (auto acoOp = memref.getDefiningOp<ArrayCoorOp>()) {
110       if (acoOp.getMemref().getType().isa<fir::BoxType>()) {
111         // TODO: Look if and how fir.box can be promoted to affine.
112         LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: cannot promote loop, "
113                                    "array memory operation uses fir.box\n";
114                    op->dump(); acoOp.dump(););
115         return false;
116       }
117       bool canPromote = true;
118       for (auto coordinate : acoOp.getIndices())
119         canPromote = canPromote && analyzeCoordinate(coordinate, op);
120       return canPromote;
121     }
122     if (auto coOp = memref.getDefiningOp<CoordinateOp>()) {
123       LLVM_DEBUG(llvm::dbgs()
124                      << "AffineLoopAnalysis: cannot promote loop, "
125                         "array memory operation uses non ArrayCoorOp\n";
126                  op->dump(); coOp.dump(););
127 
128       return false;
129     }
130     LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: unknown type of memory "
131                                "reference for array load\n";
132                op->dump(););
133     return false;
134   }
135 
136   bool analyzeMemoryAccess(fir::DoLoopOp loopOperation) {
137     for (auto loadOp : loopOperation.getOps<fir::LoadOp>())
138       if (!analyzeReference(loadOp.getMemref(), loadOp))
139         return false;
140     for (auto storeOp : loopOperation.getOps<fir::StoreOp>())
141       if (!analyzeReference(storeOp.getMemref(), storeOp))
142         return false;
143     return true;
144   }
145 
146   bool legality{};
147 };
148 } // namespace
149 
150 AffineLoopAnalysis
151 AffineFunctionAnalysis::getChildLoopAnalysis(fir::DoLoopOp op) const {
152   auto it = loopAnalysisMap.find_as(op);
153   if (it == loopAnalysisMap.end()) {
154     LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n";
155                op.dump(););
156     op.emitError("error in fetching loop analysis in AffineFunctionAnalysis\n");
157     return {};
158   }
159   return it->getSecond();
160 }
161 
162 namespace {
163 /// Calculates arguments for creating an IntegerSet. symCount, dimCount are the
164 /// final number of symbols and dimensions of the affine map. Integer set if
165 /// possible is in Optional IntegerSet.
166 struct AffineIfCondition {
167   using MaybeAffineExpr = llvm::Optional<mlir::AffineExpr>;
168 
169   explicit AffineIfCondition(mlir::Value fc) : firCondition(fc) {
170     if (auto condDef = firCondition.getDefiningOp<mlir::arith::CmpIOp>())
171       fromCmpIOp(condDef);
172   }
173 
174   bool hasIntegerSet() const { return integerSet.has_value(); }
175 
176   mlir::IntegerSet getIntegerSet() const {
177     assert(hasIntegerSet() && "integer set is missing");
178     return integerSet.getValue();
179   }
180 
181   mlir::ValueRange getAffineArgs() const { return affineArgs; }
182 
183 private:
184   MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, mlir::Value lhs,
185                                  mlir::Value rhs) {
186     return affineBinaryOp(kind, toAffineExpr(lhs), toAffineExpr(rhs));
187   }
188 
189   MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, MaybeAffineExpr lhs,
190                                  MaybeAffineExpr rhs) {
191     if (lhs && rhs)
192       return mlir::getAffineBinaryOpExpr(kind, lhs.getValue(), rhs.getValue());
193     return {};
194   }
195 
196   MaybeAffineExpr toAffineExpr(MaybeAffineExpr e) { return e; }
197 
198   MaybeAffineExpr toAffineExpr(int64_t value) {
199     return {mlir::getAffineConstantExpr(value, firCondition.getContext())};
200   }
201 
202   /// Returns an AffineExpr if it is a result of operations that can be done
203   /// in an affine expression, this includes -, +, *, rem, constant.
204   /// block arguments of a loopOp or forOp are used as dimensions
205   MaybeAffineExpr toAffineExpr(mlir::Value value) {
206     if (auto op = value.getDefiningOp<mlir::arith::SubIOp>())
207       return affineBinaryOp(
208           mlir::AffineExprKind::Add, toAffineExpr(op.getLhs()),
209           affineBinaryOp(mlir::AffineExprKind::Mul, toAffineExpr(op.getRhs()),
210                          toAffineExpr(-1)));
211     if (auto op = value.getDefiningOp<mlir::arith::AddIOp>())
212       return affineBinaryOp(mlir::AffineExprKind::Add, op.getLhs(),
213                             op.getRhs());
214     if (auto op = value.getDefiningOp<mlir::arith::MulIOp>())
215       return affineBinaryOp(mlir::AffineExprKind::Mul, op.getLhs(),
216                             op.getRhs());
217     if (auto op = value.getDefiningOp<mlir::arith::RemUIOp>())
218       return affineBinaryOp(mlir::AffineExprKind::Mod, op.getLhs(),
219                             op.getRhs());
220     if (auto op = value.getDefiningOp<mlir::arith::ConstantOp>())
221       if (auto intConstant = op.getValue().dyn_cast<IntegerAttr>())
222         return toAffineExpr(intConstant.getInt());
223     if (auto blockArg = value.dyn_cast<mlir::BlockArgument>()) {
224       affineArgs.push_back(value);
225       if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()) ||
226           isa<mlir::AffineForOp>(blockArg.getOwner()->getParentOp()))
227         return {mlir::getAffineDimExpr(dimCount++, value.getContext())};
228       return {mlir::getAffineSymbolExpr(symCount++, value.getContext())};
229     }
230     return {};
231   }
232 
233   void fromCmpIOp(mlir::arith::CmpIOp cmpOp) {
234     auto lhsAffine = toAffineExpr(cmpOp.getLhs());
235     auto rhsAffine = toAffineExpr(cmpOp.getRhs());
236     if (!lhsAffine || !rhsAffine)
237       return;
238     auto constraintPair = constraint(
239         cmpOp.getPredicate(), rhsAffine.getValue() - lhsAffine.getValue());
240     if (!constraintPair)
241       return;
242     integerSet = mlir::IntegerSet::get(dimCount, symCount,
243                                        {constraintPair.getValue().first},
244                                        {constraintPair.getValue().second});
245     return;
246   }
247 
248   llvm::Optional<std::pair<AffineExpr, bool>>
249   constraint(mlir::arith::CmpIPredicate predicate, mlir::AffineExpr basic) {
250     switch (predicate) {
251     case mlir::arith::CmpIPredicate::slt:
252       return {std::make_pair(basic - 1, false)};
253     case mlir::arith::CmpIPredicate::sle:
254       return {std::make_pair(basic, false)};
255     case mlir::arith::CmpIPredicate::sgt:
256       return {std::make_pair(1 - basic, false)};
257     case mlir::arith::CmpIPredicate::sge:
258       return {std::make_pair(0 - basic, false)};
259     case mlir::arith::CmpIPredicate::eq:
260       return {std::make_pair(basic, true)};
261     default:
262       return {};
263     }
264   }
265 
266   llvm::SmallVector<mlir::Value> affineArgs;
267   llvm::Optional<mlir::IntegerSet> integerSet;
268   mlir::Value firCondition;
269   unsigned symCount{0u};
270   unsigned dimCount{0u};
271 };
272 } // namespace
273 
274 namespace {
275 /// Analysis for affine promotion of fir.if
276 struct AffineIfAnalysis {
277   AffineIfAnalysis() = default;
278 
279   explicit AffineIfAnalysis(fir::IfOp op, AffineFunctionAnalysis &afa)
280       : legality(analyzeIf(op, afa)) {}
281 
282   bool canPromoteToAffine() { return legality; }
283 
284 private:
285   bool analyzeIf(fir::IfOp op, AffineFunctionAnalysis &afa) {
286     if (op.getNumResults() == 0)
287       return true;
288     LLVM_DEBUG(llvm::dbgs()
289                    << "AffineIfAnalysis: not promoting as op has results\n";);
290     return false;
291   }
292 
293   bool legality{};
294 };
295 } // namespace
296 
297 AffineIfAnalysis
298 AffineFunctionAnalysis::getChildIfAnalysis(fir::IfOp op) const {
299   auto it = ifAnalysisMap.find_as(op);
300   if (it == ifAnalysisMap.end()) {
301     LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n";
302                op.dump(););
303     op.emitError("error in fetching if analysis in AffineFunctionAnalysis\n");
304     return {};
305   }
306   return it->getSecond();
307 }
308 
309 /// AffineMap rewriting fir.array_coor operation to affine apply,
310 /// %dim = fir.gendim %lowerBound, %upperBound, %stride
311 /// %a = fir.array_coor %arr(%dim) %i
312 /// returning affineMap = affine_map<(i)[lb, ub, st] -> (i*st - lb)>
313 static mlir::AffineMap createArrayIndexAffineMap(unsigned dimensions,
314                                                  MLIRContext *context) {
315   auto index = mlir::getAffineConstantExpr(0, context);
316   auto accuExtent = mlir::getAffineConstantExpr(1, context);
317   for (unsigned i = 0; i < dimensions; ++i) {
318     mlir::AffineExpr idx = mlir::getAffineDimExpr(i, context),
319                      lowerBound = mlir::getAffineSymbolExpr(i * 3, context),
320                      currentExtent =
321                          mlir::getAffineSymbolExpr(i * 3 + 1, context),
322                      stride = mlir::getAffineSymbolExpr(i * 3 + 2, context),
323                      currentPart = (idx * stride - lowerBound) * accuExtent;
324     index = currentPart + index;
325     accuExtent = accuExtent * currentExtent;
326   }
327   return mlir::AffineMap::get(dimensions, dimensions * 3, index);
328 }
329 
330 static Optional<int64_t> constantIntegerLike(const mlir::Value value) {
331   if (auto definition = value.getDefiningOp<mlir::arith::ConstantOp>())
332     if (auto stepAttr = definition.getValue().dyn_cast<IntegerAttr>())
333       return stepAttr.getInt();
334   return {};
335 }
336 
337 static mlir::Type coordinateArrayElement(fir::ArrayCoorOp op) {
338   if (auto refType =
339           op.getMemref().getType().dyn_cast_or_null<ReferenceType>()) {
340     if (auto seqType = refType.getEleTy().dyn_cast_or_null<SequenceType>()) {
341       return seqType.getEleTy();
342     }
343   }
344   op.emitError(
345       "AffineLoopConversion: array type in coordinate operation not valid\n");
346   return mlir::Type();
347 }
348 
349 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeOp shape,
350                               SmallVectorImpl<mlir::Value> &indexArgs,
351                               mlir::PatternRewriter &rewriter) {
352   auto one = rewriter.create<mlir::arith::ConstantOp>(
353       acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1));
354   auto extents = shape.getExtents();
355   for (auto i = extents.begin(); i < extents.end(); i++) {
356     indexArgs.push_back(one);
357     indexArgs.push_back(*i);
358     indexArgs.push_back(one);
359   }
360 }
361 
362 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeShiftOp shape,
363                               SmallVectorImpl<mlir::Value> &indexArgs,
364                               mlir::PatternRewriter &rewriter) {
365   auto one = rewriter.create<mlir::arith::ConstantOp>(
366       acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1));
367   auto extents = shape.getPairs();
368   for (auto i = extents.begin(); i < extents.end();) {
369     indexArgs.push_back(*i++);
370     indexArgs.push_back(*i++);
371     indexArgs.push_back(one);
372   }
373 }
374 
375 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::SliceOp slice,
376                               SmallVectorImpl<mlir::Value> &indexArgs,
377                               mlir::PatternRewriter &rewriter) {
378   auto extents = slice.getTriples();
379   for (auto i = extents.begin(); i < extents.end();) {
380     indexArgs.push_back(*i++);
381     indexArgs.push_back(*i++);
382     indexArgs.push_back(*i++);
383   }
384 }
385 
386 static void populateIndexArgs(fir::ArrayCoorOp acoOp,
387                               SmallVectorImpl<mlir::Value> &indexArgs,
388                               mlir::PatternRewriter &rewriter) {
389   if (auto shape = acoOp.getShape().getDefiningOp<ShapeOp>())
390     return populateIndexArgs(acoOp, shape, indexArgs, rewriter);
391   if (auto shapeShift = acoOp.getShape().getDefiningOp<ShapeShiftOp>())
392     return populateIndexArgs(acoOp, shapeShift, indexArgs, rewriter);
393   if (auto slice = acoOp.getShape().getDefiningOp<SliceOp>())
394     return populateIndexArgs(acoOp, slice, indexArgs, rewriter);
395   return;
396 }
397 
398 /// Returns affine.apply and fir.convert from array_coor and gendims
399 static std::pair<mlir::AffineApplyOp, fir::ConvertOp>
400 createAffineOps(mlir::Value arrayRef, mlir::PatternRewriter &rewriter) {
401   auto acoOp = arrayRef.getDefiningOp<ArrayCoorOp>();
402   auto affineMap =
403       createArrayIndexAffineMap(acoOp.getIndices().size(), acoOp.getContext());
404   SmallVector<mlir::Value> indexArgs;
405   indexArgs.append(acoOp.getIndices().begin(), acoOp.getIndices().end());
406 
407   populateIndexArgs(acoOp, indexArgs, rewriter);
408 
409   auto affineApply = rewriter.create<mlir::AffineApplyOp>(acoOp.getLoc(),
410                                                           affineMap, indexArgs);
411   auto arrayElementType = coordinateArrayElement(acoOp);
412   auto newType = mlir::MemRefType::get({-1}, arrayElementType);
413   auto arrayConvert = rewriter.create<fir::ConvertOp>(acoOp.getLoc(), newType,
414                                                       acoOp.getMemref());
415   return std::make_pair(affineApply, arrayConvert);
416 }
417 
418 static void rewriteLoad(fir::LoadOp loadOp, mlir::PatternRewriter &rewriter) {
419   rewriter.setInsertionPoint(loadOp);
420   auto affineOps = createAffineOps(loadOp.getMemref(), rewriter);
421   rewriter.replaceOpWithNewOp<mlir::AffineLoadOp>(
422       loadOp, affineOps.second.getResult(), affineOps.first.getResult());
423 }
424 
425 static void rewriteStore(fir::StoreOp storeOp,
426                          mlir::PatternRewriter &rewriter) {
427   rewriter.setInsertionPoint(storeOp);
428   auto affineOps = createAffineOps(storeOp.getMemref(), rewriter);
429   rewriter.replaceOpWithNewOp<mlir::AffineStoreOp>(storeOp, storeOp.getValue(),
430                                                    affineOps.second.getResult(),
431                                                    affineOps.first.getResult());
432 }
433 
434 static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) {
435   for (auto &bodyOp : block->getOperations()) {
436     if (isa<fir::LoadOp>(bodyOp))
437       rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter);
438     if (isa<fir::StoreOp>(bodyOp))
439       rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter);
440   }
441 }
442 
443 namespace {
444 /// Convert `fir.do_loop` to `affine.for`, creates fir.convert for arrays to
445 /// memref, rewrites array_coor to affine.apply with affine_map. Rewrites fir
446 /// loads and stores to affine.
447 class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
448 public:
449   using OpRewritePattern::OpRewritePattern;
450   AffineLoopConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
451       : OpRewritePattern(context), functionAnalysis(afa) {}
452 
453   mlir::LogicalResult
454   matchAndRewrite(fir::DoLoopOp loop,
455                   mlir::PatternRewriter &rewriter) const override {
456     LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n";
457                loop.dump(););
458     LLVM_ATTRIBUTE_UNUSED auto loopAnalysis =
459         functionAnalysis.getChildLoopAnalysis(loop);
460     auto &loopOps = loop.getBody()->getOperations();
461     auto loopAndIndex = createAffineFor(loop, rewriter);
462     auto affineFor = loopAndIndex.first;
463     auto inductionVar = loopAndIndex.second;
464 
465     rewriter.startRootUpdate(affineFor.getOperation());
466     affineFor.getBody()->getOperations().splice(
467         std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(),
468         std::prev(loopOps.end()));
469     rewriter.finalizeRootUpdate(affineFor.getOperation());
470 
471     rewriter.startRootUpdate(loop.getOperation());
472     loop.getInductionVar().replaceAllUsesWith(inductionVar);
473     rewriter.finalizeRootUpdate(loop.getOperation());
474 
475     rewriteMemoryOps(affineFor.getBody(), rewriter);
476 
477     LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n";
478                affineFor.dump(););
479     rewriter.replaceOp(loop, affineFor.getOperation()->getResults());
480     return success();
481   }
482 
483 private:
484   std::pair<mlir::AffineForOp, mlir::Value>
485   createAffineFor(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const {
486     if (auto constantStep = constantIntegerLike(op.getStep()))
487       if (constantStep.getValue() > 0)
488         return positiveConstantStep(op, constantStep.getValue(), rewriter);
489     return genericBounds(op, rewriter);
490   }
491 
492   // when step for the loop is positive compile time constant
493   std::pair<mlir::AffineForOp, mlir::Value>
494   positiveConstantStep(fir::DoLoopOp op, int64_t step,
495                        mlir::PatternRewriter &rewriter) const {
496     auto affineFor = rewriter.create<mlir::AffineForOp>(
497         op.getLoc(), ValueRange(op.getLowerBound()),
498         mlir::AffineMap::get(0, 1,
499                              mlir::getAffineSymbolExpr(0, op.getContext())),
500         ValueRange(op.getUpperBound()),
501         mlir::AffineMap::get(0, 1,
502                              1 + mlir::getAffineSymbolExpr(0, op.getContext())),
503         step);
504     return std::make_pair(affineFor, affineFor.getInductionVar());
505   }
506 
507   std::pair<mlir::AffineForOp, mlir::Value>
508   genericBounds(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const {
509     auto lowerBound = mlir::getAffineSymbolExpr(0, op.getContext());
510     auto upperBound = mlir::getAffineSymbolExpr(1, op.getContext());
511     auto step = mlir::getAffineSymbolExpr(2, op.getContext());
512     mlir::AffineMap upperBoundMap = mlir::AffineMap::get(
513         0, 3, (upperBound - lowerBound + step).floorDiv(step));
514     auto genericUpperBound = rewriter.create<mlir::AffineApplyOp>(
515         op.getLoc(), upperBoundMap,
516         ValueRange({op.getLowerBound(), op.getUpperBound(), op.getStep()}));
517     auto actualIndexMap = mlir::AffineMap::get(
518         1, 2,
519         (lowerBound + mlir::getAffineDimExpr(0, op.getContext())) *
520             mlir::getAffineSymbolExpr(1, op.getContext()));
521 
522     auto affineFor = rewriter.create<mlir::AffineForOp>(
523         op.getLoc(), ValueRange(),
524         AffineMap::getConstantMap(0, op.getContext()),
525         genericUpperBound.getResult(),
526         mlir::AffineMap::get(0, 1,
527                              1 + mlir::getAffineSymbolExpr(0, op.getContext())),
528         1);
529     rewriter.setInsertionPointToStart(affineFor.getBody());
530     auto actualIndex = rewriter.create<mlir::AffineApplyOp>(
531         op.getLoc(), actualIndexMap,
532         ValueRange(
533             {affineFor.getInductionVar(), op.getLowerBound(), op.getStep()}));
534     return std::make_pair(affineFor, actualIndex.getResult());
535   }
536 
537   AffineFunctionAnalysis &functionAnalysis;
538 };
539 
540 /// Convert `fir.if` to `affine.if`.
541 class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
542 public:
543   using OpRewritePattern::OpRewritePattern;
544   AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
545       : OpRewritePattern(context) {}
546   mlir::LogicalResult
547   matchAndRewrite(fir::IfOp op,
548                   mlir::PatternRewriter &rewriter) const override {
549     LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n";
550                op.dump(););
551     auto &ifOps = op.getThenRegion().front().getOperations();
552     auto affineCondition = AffineIfCondition(op.getCondition());
553     if (!affineCondition.hasIntegerSet()) {
554       LLVM_DEBUG(
555           llvm::dbgs()
556               << "AffineIfConversion: couldn't calculate affine condition\n";);
557       return failure();
558     }
559     auto affineIf = rewriter.create<mlir::AffineIfOp>(
560         op.getLoc(), affineCondition.getIntegerSet(),
561         affineCondition.getAffineArgs(), !op.getElseRegion().empty());
562     rewriter.startRootUpdate(affineIf);
563     affineIf.getThenBlock()->getOperations().splice(
564         std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(),
565         std::prev(ifOps.end()));
566     if (!op.getElseRegion().empty()) {
567       auto &otherOps = op.getElseRegion().front().getOperations();
568       affineIf.getElseBlock()->getOperations().splice(
569           std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(),
570           std::prev(otherOps.end()));
571     }
572     rewriter.finalizeRootUpdate(affineIf);
573     rewriteMemoryOps(affineIf.getBody(), rewriter);
574 
575     LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n";
576                affineIf.dump(););
577     rewriter.replaceOp(op, affineIf.getOperation()->getResults());
578     return success();
579   }
580 };
581 
582 /// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases
583 /// where such a promotion is possible.
584 class AffineDialectPromotion
585     : public AffineDialectPromotionBase<AffineDialectPromotion> {
586 public:
587   void runOnOperation() override {
588 
589     auto *context = &getContext();
590     auto function = getOperation();
591     markAllAnalysesPreserved();
592     auto functionAnalysis = AffineFunctionAnalysis(function);
593     mlir::RewritePatternSet patterns(context);
594     patterns.insert<AffineIfConversion>(context, functionAnalysis);
595     patterns.insert<AffineLoopConversion>(context, functionAnalysis);
596     mlir::ConversionTarget target = *context;
597     target.addLegalDialect<
598         mlir::AffineDialect, FIROpsDialect, mlir::scf::SCFDialect,
599         mlir::arith::ArithmeticDialect, mlir::func::FuncDialect>();
600     target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) {
601       return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine());
602     });
603     target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis](
604                                                fir::DoLoopOp op) {
605       return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine());
606     });
607 
608     LLVM_DEBUG(llvm::dbgs()
609                    << "AffineDialectPromotion: running promotion on: \n";
610                function.print(llvm::dbgs()););
611     // apply the patterns
612     if (mlir::failed(mlir::applyPartialConversion(function, target,
613                                                   std::move(patterns)))) {
614       mlir::emitError(mlir::UnknownLoc::get(context),
615                       "error in converting to affine dialect\n");
616       signalPassFailure();
617     }
618   }
619 };
620 } // namespace
621 
622 /// Convert FIR loop constructs to the Affine dialect
623 std::unique_ptr<mlir::Pass> fir::createPromoteToAffinePass() {
624   return std::make_unique<AffineDialectPromotion>();
625 }
626