xref: /llvm-project/flang/lib/Optimizer/Transforms/AffinePromotion.cpp (revision 399638f98cdc2405c6a2e85f3cbba175fabaf858)
1 //===-- AffinePromotion.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This transformation is a prototype that promotes FIR loop operations
// to affine dialect operations.
11 // It is not part of the production pipeline and would need more work in order
12 // to be used in production.
13 // More information can be found in this presentation:
14 // https://slides.com/rajanwalia/deck
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "flang/Optimizer/Dialect/FIRDialect.h"
19 #include "flang/Optimizer/Dialect/FIROps.h"
20 #include "flang/Optimizer/Dialect/FIRType.h"
21 #include "flang/Optimizer/Transforms/Passes.h"
22 #include "mlir/Dialect/Affine/IR/AffineOps.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/Dialect/SCF/IR/SCF.h"
25 #include "mlir/IR/BuiltinAttributes.h"
26 #include "mlir/IR/IntegerSet.h"
27 #include "mlir/IR/Visitors.h"
28 #include "mlir/Transforms/DialectConversion.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/Optional.h"
31 #include "llvm/Support/Debug.h"
32 
33 namespace fir {
34 #define GEN_PASS_DEF_AFFINEDIALECTPROMOTION
35 #include "flang/Optimizer/Transforms/Passes.h.inc"
36 } // namespace fir
37 
38 #define DEBUG_TYPE "flang-affine-promotion"
39 
40 using namespace fir;
41 using namespace mlir;
42 
43 namespace {
44 struct AffineLoopAnalysis;
45 struct AffineIfAnalysis;
46 
47 /// Stores analysis objects for all loops and if operations inside a function
48 /// these analysis are used twice, first for marking operations for rewrite and
49 /// second when doing rewrite.
50 struct AffineFunctionAnalysis {
51   explicit AffineFunctionAnalysis(mlir::func::FuncOp funcOp) {
52     for (fir::DoLoopOp op : funcOp.getOps<fir::DoLoopOp>())
53       loopAnalysisMap.try_emplace(op, op, *this);
54   }
55 
56   AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const;
57 
58   AffineIfAnalysis getChildIfAnalysis(fir::IfOp op) const;
59 
60   llvm::DenseMap<mlir::Operation *, AffineLoopAnalysis> loopAnalysisMap;
61   llvm::DenseMap<mlir::Operation *, AffineIfAnalysis> ifAnalysisMap;
62 };
63 } // namespace
64 
65 static bool analyzeCoordinate(mlir::Value coordinate, mlir::Operation *op) {
66   if (auto blockArg = coordinate.dyn_cast<mlir::BlockArgument>()) {
67     if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()))
68       return true;
69     LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a "
70                                "loop induction variable (owner not loopOp)\n";
71                op->dump());
72     return false;
73   }
74   LLVM_DEBUG(
75       llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a loop "
76                       "induction variable (not a block argument)\n";
77       op->dump(); coordinate.getDefiningOp()->dump());
78   return false;
79 }
80 
81 namespace {
82 struct AffineLoopAnalysis {
83   AffineLoopAnalysis() = default;
84 
85   explicit AffineLoopAnalysis(fir::DoLoopOp op, AffineFunctionAnalysis &afa)
86       : legality(analyzeLoop(op, afa)) {}
87 
88   bool canPromoteToAffine() { return legality; }
89 
90 private:
91   bool analyzeBody(fir::DoLoopOp loopOperation,
92                    AffineFunctionAnalysis &functionAnalysis) {
93     for (auto loopOp : loopOperation.getOps<fir::DoLoopOp>()) {
94       auto analysis = functionAnalysis.loopAnalysisMap
95                           .try_emplace(loopOp, loopOp, functionAnalysis)
96                           .first->getSecond();
97       if (!analysis.canPromoteToAffine())
98         return false;
99     }
100     for (auto ifOp : loopOperation.getOps<fir::IfOp>())
101       functionAnalysis.ifAnalysisMap.try_emplace(ifOp, ifOp, functionAnalysis);
102     return true;
103   }
104 
105   bool analyzeLoop(fir::DoLoopOp loopOperation,
106                    AffineFunctionAnalysis &functionAnalysis) {
107     LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n"; loopOperation.dump(););
108     return analyzeMemoryAccess(loopOperation) &&
109            analyzeBody(loopOperation, functionAnalysis);
110   }
111 
112   bool analyzeReference(mlir::Value memref, mlir::Operation *op) {
113     if (auto acoOp = memref.getDefiningOp<ArrayCoorOp>()) {
114       if (acoOp.getMemref().getType().isa<fir::BoxType>()) {
115         // TODO: Look if and how fir.box can be promoted to affine.
116         LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: cannot promote loop, "
117                                    "array memory operation uses fir.box\n";
118                    op->dump(); acoOp.dump(););
119         return false;
120       }
121       bool canPromote = true;
122       for (auto coordinate : acoOp.getIndices())
123         canPromote = canPromote && analyzeCoordinate(coordinate, op);
124       return canPromote;
125     }
126     if (auto coOp = memref.getDefiningOp<CoordinateOp>()) {
127       LLVM_DEBUG(llvm::dbgs()
128                      << "AffineLoopAnalysis: cannot promote loop, "
129                         "array memory operation uses non ArrayCoorOp\n";
130                  op->dump(); coOp.dump(););
131 
132       return false;
133     }
134     LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: unknown type of memory "
135                                "reference for array load\n";
136                op->dump(););
137     return false;
138   }
139 
140   bool analyzeMemoryAccess(fir::DoLoopOp loopOperation) {
141     for (auto loadOp : loopOperation.getOps<fir::LoadOp>())
142       if (!analyzeReference(loadOp.getMemref(), loadOp))
143         return false;
144     for (auto storeOp : loopOperation.getOps<fir::StoreOp>())
145       if (!analyzeReference(storeOp.getMemref(), storeOp))
146         return false;
147     return true;
148   }
149 
150   bool legality{};
151 };
152 } // namespace
153 
154 AffineLoopAnalysis
155 AffineFunctionAnalysis::getChildLoopAnalysis(fir::DoLoopOp op) const {
156   auto it = loopAnalysisMap.find_as(op);
157   if (it == loopAnalysisMap.end()) {
158     LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n";
159                op.dump(););
160     op.emitError("error in fetching loop analysis in AffineFunctionAnalysis\n");
161     return {};
162   }
163   return it->getSecond();
164 }
165 
166 namespace {
167 /// Calculates arguments for creating an IntegerSet. symCount, dimCount are the
168 /// final number of symbols and dimensions of the affine map. Integer set if
169 /// possible is in Optional IntegerSet.
170 struct AffineIfCondition {
171   using MaybeAffineExpr = llvm::Optional<mlir::AffineExpr>;
172 
173   explicit AffineIfCondition(mlir::Value fc) : firCondition(fc) {
174     if (auto condDef = firCondition.getDefiningOp<mlir::arith::CmpIOp>())
175       fromCmpIOp(condDef);
176   }
177 
178   bool hasIntegerSet() const { return integerSet.has_value(); }
179 
180   mlir::IntegerSet getIntegerSet() const {
181     assert(hasIntegerSet() && "integer set is missing");
182     return integerSet.value();
183   }
184 
185   mlir::ValueRange getAffineArgs() const { return affineArgs; }
186 
187 private:
188   MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, mlir::Value lhs,
189                                  mlir::Value rhs) {
190     return affineBinaryOp(kind, toAffineExpr(lhs), toAffineExpr(rhs));
191   }
192 
193   MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, MaybeAffineExpr lhs,
194                                  MaybeAffineExpr rhs) {
195     if (lhs && rhs)
196       return mlir::getAffineBinaryOpExpr(kind, *lhs, *rhs);
197     return {};
198   }
199 
200   MaybeAffineExpr toAffineExpr(MaybeAffineExpr e) { return e; }
201 
202   MaybeAffineExpr toAffineExpr(int64_t value) {
203     return {mlir::getAffineConstantExpr(value, firCondition.getContext())};
204   }
205 
206   /// Returns an AffineExpr if it is a result of operations that can be done
207   /// in an affine expression, this includes -, +, *, rem, constant.
208   /// block arguments of a loopOp or forOp are used as dimensions
209   MaybeAffineExpr toAffineExpr(mlir::Value value) {
210     if (auto op = value.getDefiningOp<mlir::arith::SubIOp>())
211       return affineBinaryOp(
212           mlir::AffineExprKind::Add, toAffineExpr(op.getLhs()),
213           affineBinaryOp(mlir::AffineExprKind::Mul, toAffineExpr(op.getRhs()),
214                          toAffineExpr(-1)));
215     if (auto op = value.getDefiningOp<mlir::arith::AddIOp>())
216       return affineBinaryOp(mlir::AffineExprKind::Add, op.getLhs(),
217                             op.getRhs());
218     if (auto op = value.getDefiningOp<mlir::arith::MulIOp>())
219       return affineBinaryOp(mlir::AffineExprKind::Mul, op.getLhs(),
220                             op.getRhs());
221     if (auto op = value.getDefiningOp<mlir::arith::RemUIOp>())
222       return affineBinaryOp(mlir::AffineExprKind::Mod, op.getLhs(),
223                             op.getRhs());
224     if (auto op = value.getDefiningOp<mlir::arith::ConstantOp>())
225       if (auto intConstant = op.getValue().dyn_cast<IntegerAttr>())
226         return toAffineExpr(intConstant.getInt());
227     if (auto blockArg = value.dyn_cast<mlir::BlockArgument>()) {
228       affineArgs.push_back(value);
229       if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()) ||
230           isa<mlir::AffineForOp>(blockArg.getOwner()->getParentOp()))
231         return {mlir::getAffineDimExpr(dimCount++, value.getContext())};
232       return {mlir::getAffineSymbolExpr(symCount++, value.getContext())};
233     }
234     return {};
235   }
236 
237   void fromCmpIOp(mlir::arith::CmpIOp cmpOp) {
238     auto lhsAffine = toAffineExpr(cmpOp.getLhs());
239     auto rhsAffine = toAffineExpr(cmpOp.getRhs());
240     if (!lhsAffine || !rhsAffine)
241       return;
242     auto constraintPair =
243         constraint(cmpOp.getPredicate(), *rhsAffine - *lhsAffine);
244     if (!constraintPair)
245       return;
246     integerSet = mlir::IntegerSet::get(
247         dimCount, symCount, {constraintPair->first}, {constraintPair->second});
248   }
249 
250   llvm::Optional<std::pair<AffineExpr, bool>>
251   constraint(mlir::arith::CmpIPredicate predicate, mlir::AffineExpr basic) {
252     switch (predicate) {
253     case mlir::arith::CmpIPredicate::slt:
254       return {std::make_pair(basic - 1, false)};
255     case mlir::arith::CmpIPredicate::sle:
256       return {std::make_pair(basic, false)};
257     case mlir::arith::CmpIPredicate::sgt:
258       return {std::make_pair(1 - basic, false)};
259     case mlir::arith::CmpIPredicate::sge:
260       return {std::make_pair(0 - basic, false)};
261     case mlir::arith::CmpIPredicate::eq:
262       return {std::make_pair(basic, true)};
263     default:
264       return {};
265     }
266   }
267 
268   llvm::SmallVector<mlir::Value> affineArgs;
269   llvm::Optional<mlir::IntegerSet> integerSet;
270   mlir::Value firCondition;
271   unsigned symCount{0u};
272   unsigned dimCount{0u};
273 };
274 } // namespace
275 
276 namespace {
277 /// Analysis for affine promotion of fir.if
278 struct AffineIfAnalysis {
279   AffineIfAnalysis() = default;
280 
281   explicit AffineIfAnalysis(fir::IfOp op, AffineFunctionAnalysis &afa)
282       : legality(analyzeIf(op, afa)) {}
283 
284   bool canPromoteToAffine() { return legality; }
285 
286 private:
287   bool analyzeIf(fir::IfOp op, AffineFunctionAnalysis &afa) {
288     if (op.getNumResults() == 0)
289       return true;
290     LLVM_DEBUG(llvm::dbgs()
291                    << "AffineIfAnalysis: not promoting as op has results\n";);
292     return false;
293   }
294 
295   bool legality{};
296 };
297 } // namespace
298 
299 AffineIfAnalysis
300 AffineFunctionAnalysis::getChildIfAnalysis(fir::IfOp op) const {
301   auto it = ifAnalysisMap.find_as(op);
302   if (it == ifAnalysisMap.end()) {
303     LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n";
304                op.dump(););
305     op.emitError("error in fetching if analysis in AffineFunctionAnalysis\n");
306     return {};
307   }
308   return it->getSecond();
309 }
310 
311 /// AffineMap rewriting fir.array_coor operation to affine apply,
312 /// %dim = fir.gendim %lowerBound, %upperBound, %stride
313 /// %a = fir.array_coor %arr(%dim) %i
314 /// returning affineMap = affine_map<(i)[lb, ub, st] -> (i*st - lb)>
315 static mlir::AffineMap createArrayIndexAffineMap(unsigned dimensions,
316                                                  MLIRContext *context) {
317   auto index = mlir::getAffineConstantExpr(0, context);
318   auto accuExtent = mlir::getAffineConstantExpr(1, context);
319   for (unsigned i = 0; i < dimensions; ++i) {
320     mlir::AffineExpr idx = mlir::getAffineDimExpr(i, context),
321                      lowerBound = mlir::getAffineSymbolExpr(i * 3, context),
322                      currentExtent =
323                          mlir::getAffineSymbolExpr(i * 3 + 1, context),
324                      stride = mlir::getAffineSymbolExpr(i * 3 + 2, context),
325                      currentPart = (idx * stride - lowerBound) * accuExtent;
326     index = currentPart + index;
327     accuExtent = accuExtent * currentExtent;
328   }
329   return mlir::AffineMap::get(dimensions, dimensions * 3, index);
330 }
331 
332 static Optional<int64_t> constantIntegerLike(const mlir::Value value) {
333   if (auto definition = value.getDefiningOp<mlir::arith::ConstantOp>())
334     if (auto stepAttr = definition.getValue().dyn_cast<IntegerAttr>())
335       return stepAttr.getInt();
336   return {};
337 }
338 
339 static mlir::Type coordinateArrayElement(fir::ArrayCoorOp op) {
340   if (auto refType =
341           op.getMemref().getType().dyn_cast_or_null<ReferenceType>()) {
342     if (auto seqType = refType.getEleTy().dyn_cast_or_null<SequenceType>()) {
343       return seqType.getEleTy();
344     }
345   }
346   op.emitError(
347       "AffineLoopConversion: array type in coordinate operation not valid\n");
348   return mlir::Type();
349 }
350 
351 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeOp shape,
352                               SmallVectorImpl<mlir::Value> &indexArgs,
353                               mlir::PatternRewriter &rewriter) {
354   auto one = rewriter.create<mlir::arith::ConstantOp>(
355       acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1));
356   auto extents = shape.getExtents();
357   for (auto i = extents.begin(); i < extents.end(); i++) {
358     indexArgs.push_back(one);
359     indexArgs.push_back(*i);
360     indexArgs.push_back(one);
361   }
362 }
363 
364 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeShiftOp shape,
365                               SmallVectorImpl<mlir::Value> &indexArgs,
366                               mlir::PatternRewriter &rewriter) {
367   auto one = rewriter.create<mlir::arith::ConstantOp>(
368       acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1));
369   auto extents = shape.getPairs();
370   for (auto i = extents.begin(); i < extents.end();) {
371     indexArgs.push_back(*i++);
372     indexArgs.push_back(*i++);
373     indexArgs.push_back(one);
374   }
375 }
376 
377 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::SliceOp slice,
378                               SmallVectorImpl<mlir::Value> &indexArgs,
379                               mlir::PatternRewriter &rewriter) {
380   auto extents = slice.getTriples();
381   for (auto i = extents.begin(); i < extents.end();) {
382     indexArgs.push_back(*i++);
383     indexArgs.push_back(*i++);
384     indexArgs.push_back(*i++);
385   }
386 }
387 
388 static void populateIndexArgs(fir::ArrayCoorOp acoOp,
389                               SmallVectorImpl<mlir::Value> &indexArgs,
390                               mlir::PatternRewriter &rewriter) {
391   if (auto shape = acoOp.getShape().getDefiningOp<ShapeOp>())
392     return populateIndexArgs(acoOp, shape, indexArgs, rewriter);
393   if (auto shapeShift = acoOp.getShape().getDefiningOp<ShapeShiftOp>())
394     return populateIndexArgs(acoOp, shapeShift, indexArgs, rewriter);
395   if (auto slice = acoOp.getShape().getDefiningOp<SliceOp>())
396     return populateIndexArgs(acoOp, slice, indexArgs, rewriter);
397 }
398 
399 /// Returns affine.apply and fir.convert from array_coor and gendims
400 static std::pair<mlir::AffineApplyOp, fir::ConvertOp>
401 createAffineOps(mlir::Value arrayRef, mlir::PatternRewriter &rewriter) {
402   auto acoOp = arrayRef.getDefiningOp<ArrayCoorOp>();
403   auto affineMap =
404       createArrayIndexAffineMap(acoOp.getIndices().size(), acoOp.getContext());
405   SmallVector<mlir::Value> indexArgs;
406   indexArgs.append(acoOp.getIndices().begin(), acoOp.getIndices().end());
407 
408   populateIndexArgs(acoOp, indexArgs, rewriter);
409 
410   auto affineApply = rewriter.create<mlir::AffineApplyOp>(acoOp.getLoc(),
411                                                           affineMap, indexArgs);
412   auto arrayElementType = coordinateArrayElement(acoOp);
413   auto newType =
414       mlir::MemRefType::get({mlir::ShapedType::kDynamic}, arrayElementType);
415   auto arrayConvert = rewriter.create<fir::ConvertOp>(acoOp.getLoc(), newType,
416                                                       acoOp.getMemref());
417   return std::make_pair(affineApply, arrayConvert);
418 }
419 
420 static void rewriteLoad(fir::LoadOp loadOp, mlir::PatternRewriter &rewriter) {
421   rewriter.setInsertionPoint(loadOp);
422   auto affineOps = createAffineOps(loadOp.getMemref(), rewriter);
423   rewriter.replaceOpWithNewOp<mlir::AffineLoadOp>(
424       loadOp, affineOps.second.getResult(), affineOps.first.getResult());
425 }
426 
427 static void rewriteStore(fir::StoreOp storeOp,
428                          mlir::PatternRewriter &rewriter) {
429   rewriter.setInsertionPoint(storeOp);
430   auto affineOps = createAffineOps(storeOp.getMemref(), rewriter);
431   rewriter.replaceOpWithNewOp<mlir::AffineStoreOp>(storeOp, storeOp.getValue(),
432                                                    affineOps.second.getResult(),
433                                                    affineOps.first.getResult());
434 }
435 
436 static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) {
437   for (auto &bodyOp : block->getOperations()) {
438     if (isa<fir::LoadOp>(bodyOp))
439       rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter);
440     if (isa<fir::StoreOp>(bodyOp))
441       rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter);
442   }
443 }
444 
445 namespace {
446 /// Convert `fir.do_loop` to `affine.for`, creates fir.convert for arrays to
447 /// memref, rewrites array_coor to affine.apply with affine_map. Rewrites fir
448 /// loads and stores to affine.
449 class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
450 public:
451   using OpRewritePattern::OpRewritePattern;
452   AffineLoopConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
453       : OpRewritePattern(context), functionAnalysis(afa) {}
454 
455   mlir::LogicalResult
456   matchAndRewrite(fir::DoLoopOp loop,
457                   mlir::PatternRewriter &rewriter) const override {
458     LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n";
459                loop.dump(););
460     LLVM_ATTRIBUTE_UNUSED auto loopAnalysis =
461         functionAnalysis.getChildLoopAnalysis(loop);
462     auto &loopOps = loop.getBody()->getOperations();
463     auto loopAndIndex = createAffineFor(loop, rewriter);
464     auto affineFor = loopAndIndex.first;
465     auto inductionVar = loopAndIndex.second;
466 
467     rewriter.startRootUpdate(affineFor.getOperation());
468     affineFor.getBody()->getOperations().splice(
469         std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(),
470         std::prev(loopOps.end()));
471     rewriter.finalizeRootUpdate(affineFor.getOperation());
472 
473     rewriter.startRootUpdate(loop.getOperation());
474     loop.getInductionVar().replaceAllUsesWith(inductionVar);
475     rewriter.finalizeRootUpdate(loop.getOperation());
476 
477     rewriteMemoryOps(affineFor.getBody(), rewriter);
478 
479     LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n";
480                affineFor.dump(););
481     rewriter.replaceOp(loop, affineFor.getOperation()->getResults());
482     return success();
483   }
484 
485 private:
486   std::pair<mlir::AffineForOp, mlir::Value>
487   createAffineFor(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const {
488     if (auto constantStep = constantIntegerLike(op.getStep()))
489       if (*constantStep > 0)
490         return positiveConstantStep(op, *constantStep, rewriter);
491     return genericBounds(op, rewriter);
492   }
493 
494   // when step for the loop is positive compile time constant
495   std::pair<mlir::AffineForOp, mlir::Value>
496   positiveConstantStep(fir::DoLoopOp op, int64_t step,
497                        mlir::PatternRewriter &rewriter) const {
498     auto affineFor = rewriter.create<mlir::AffineForOp>(
499         op.getLoc(), ValueRange(op.getLowerBound()),
500         mlir::AffineMap::get(0, 1,
501                              mlir::getAffineSymbolExpr(0, op.getContext())),
502         ValueRange(op.getUpperBound()),
503         mlir::AffineMap::get(0, 1,
504                              1 + mlir::getAffineSymbolExpr(0, op.getContext())),
505         step);
506     return std::make_pair(affineFor, affineFor.getInductionVar());
507   }
508 
509   std::pair<mlir::AffineForOp, mlir::Value>
510   genericBounds(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const {
511     auto lowerBound = mlir::getAffineSymbolExpr(0, op.getContext());
512     auto upperBound = mlir::getAffineSymbolExpr(1, op.getContext());
513     auto step = mlir::getAffineSymbolExpr(2, op.getContext());
514     mlir::AffineMap upperBoundMap = mlir::AffineMap::get(
515         0, 3, (upperBound - lowerBound + step).floorDiv(step));
516     auto genericUpperBound = rewriter.create<mlir::AffineApplyOp>(
517         op.getLoc(), upperBoundMap,
518         ValueRange({op.getLowerBound(), op.getUpperBound(), op.getStep()}));
519     auto actualIndexMap = mlir::AffineMap::get(
520         1, 2,
521         (lowerBound + mlir::getAffineDimExpr(0, op.getContext())) *
522             mlir::getAffineSymbolExpr(1, op.getContext()));
523 
524     auto affineFor = rewriter.create<mlir::AffineForOp>(
525         op.getLoc(), ValueRange(),
526         AffineMap::getConstantMap(0, op.getContext()),
527         genericUpperBound.getResult(),
528         mlir::AffineMap::get(0, 1,
529                              1 + mlir::getAffineSymbolExpr(0, op.getContext())),
530         1);
531     rewriter.setInsertionPointToStart(affineFor.getBody());
532     auto actualIndex = rewriter.create<mlir::AffineApplyOp>(
533         op.getLoc(), actualIndexMap,
534         ValueRange(
535             {affineFor.getInductionVar(), op.getLowerBound(), op.getStep()}));
536     return std::make_pair(affineFor, actualIndex.getResult());
537   }
538 
539   AffineFunctionAnalysis &functionAnalysis;
540 };
541 
542 /// Convert `fir.if` to `affine.if`.
543 class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
544 public:
545   using OpRewritePattern::OpRewritePattern;
546   AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
547       : OpRewritePattern(context) {}
548   mlir::LogicalResult
549   matchAndRewrite(fir::IfOp op,
550                   mlir::PatternRewriter &rewriter) const override {
551     LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n";
552                op.dump(););
553     auto &ifOps = op.getThenRegion().front().getOperations();
554     auto affineCondition = AffineIfCondition(op.getCondition());
555     if (!affineCondition.hasIntegerSet()) {
556       LLVM_DEBUG(
557           llvm::dbgs()
558               << "AffineIfConversion: couldn't calculate affine condition\n";);
559       return failure();
560     }
561     auto affineIf = rewriter.create<mlir::AffineIfOp>(
562         op.getLoc(), affineCondition.getIntegerSet(),
563         affineCondition.getAffineArgs(), !op.getElseRegion().empty());
564     rewriter.startRootUpdate(affineIf);
565     affineIf.getThenBlock()->getOperations().splice(
566         std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(),
567         std::prev(ifOps.end()));
568     if (!op.getElseRegion().empty()) {
569       auto &otherOps = op.getElseRegion().front().getOperations();
570       affineIf.getElseBlock()->getOperations().splice(
571           std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(),
572           std::prev(otherOps.end()));
573     }
574     rewriter.finalizeRootUpdate(affineIf);
575     rewriteMemoryOps(affineIf.getBody(), rewriter);
576 
577     LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n";
578                affineIf.dump(););
579     rewriter.replaceOp(op, affineIf.getOperation()->getResults());
580     return success();
581   }
582 };
583 
584 /// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases
585 /// where such a promotion is possible.
586 class AffineDialectPromotion
587     : public fir::impl::AffineDialectPromotionBase<AffineDialectPromotion> {
588 public:
589   void runOnOperation() override {
590 
591     auto *context = &getContext();
592     auto function = getOperation();
593     markAllAnalysesPreserved();
594     auto functionAnalysis = AffineFunctionAnalysis(function);
595     mlir::RewritePatternSet patterns(context);
596     patterns.insert<AffineIfConversion>(context, functionAnalysis);
597     patterns.insert<AffineLoopConversion>(context, functionAnalysis);
598     mlir::ConversionTarget target = *context;
599     target.addLegalDialect<mlir::AffineDialect, FIROpsDialect,
600                            mlir::scf::SCFDialect, mlir::arith::ArithDialect,
601                            mlir::func::FuncDialect>();
602     target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) {
603       return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine());
604     });
605     target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis](
606                                                fir::DoLoopOp op) {
607       return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine());
608     });
609 
610     LLVM_DEBUG(llvm::dbgs()
611                    << "AffineDialectPromotion: running promotion on: \n";
612                function.print(llvm::dbgs()););
613     // apply the patterns
614     if (mlir::failed(mlir::applyPartialConversion(function, target,
615                                                   std::move(patterns)))) {
616       mlir::emitError(mlir::UnknownLoc::get(context),
617                       "error in converting to affine dialect\n");
618       signalPassFailure();
619     }
620   }
621 };
622 } // namespace
623 
624 /// Convert FIR loop constructs to the Affine dialect
625 std::unique_ptr<mlir::Pass> fir::createPromoteToAffinePass() {
626   return std::make_unique<AffineDialectPromotion>();
627 }
628