xref: /llvm-project/flang/lib/Optimizer/Transforms/AffinePromotion.cpp (revision c09215860fd5c32012ef4fdc5a001485a04fe85a)
1 //===-- AffinePromotion.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This transformation is a prototype that promotes FIR loop operations
// to affine dialect operations.
11 // It is not part of the production pipeline and would need more work in order
12 // to be used in production.
13 // More information can be found in this presentation:
14 // https://slides.com/rajanwalia/deck
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "flang/Optimizer/Dialect/FIRDialect.h"
19 #include "flang/Optimizer/Dialect/FIROps.h"
20 #include "flang/Optimizer/Dialect/FIRType.h"
21 #include "flang/Optimizer/Transforms/Passes.h"
22 #include "mlir/Dialect/Affine/IR/AffineOps.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/Dialect/SCF/IR/SCF.h"
25 #include "mlir/IR/BuiltinAttributes.h"
26 #include "mlir/IR/IntegerSet.h"
27 #include "mlir/IR/Visitors.h"
28 #include "mlir/Transforms/DialectConversion.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/Optional.h"
31 #include "llvm/Support/Debug.h"
32 #include <optional>
33 
34 namespace fir {
35 #define GEN_PASS_DEF_AFFINEDIALECTPROMOTION
36 #include "flang/Optimizer/Transforms/Passes.h.inc"
37 } // namespace fir
38 
39 #define DEBUG_TYPE "flang-affine-promotion"
40 
41 using namespace fir;
42 using namespace mlir;
43 
44 namespace {
45 struct AffineLoopAnalysis;
46 struct AffineIfAnalysis;
47 
48 /// Stores analysis objects for all loops and if operations inside a function
49 /// these analysis are used twice, first for marking operations for rewrite and
50 /// second when doing rewrite.
51 struct AffineFunctionAnalysis {
52   explicit AffineFunctionAnalysis(mlir::func::FuncOp funcOp) {
53     for (fir::DoLoopOp op : funcOp.getOps<fir::DoLoopOp>())
54       loopAnalysisMap.try_emplace(op, op, *this);
55   }
56 
57   AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const;
58 
59   AffineIfAnalysis getChildIfAnalysis(fir::IfOp op) const;
60 
61   llvm::DenseMap<mlir::Operation *, AffineLoopAnalysis> loopAnalysisMap;
62   llvm::DenseMap<mlir::Operation *, AffineIfAnalysis> ifAnalysisMap;
63 };
64 } // namespace
65 
66 static bool analyzeCoordinate(mlir::Value coordinate, mlir::Operation *op) {
67   if (auto blockArg = coordinate.dyn_cast<mlir::BlockArgument>()) {
68     if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()))
69       return true;
70     LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a "
71                                "loop induction variable (owner not loopOp)\n";
72                op->dump());
73     return false;
74   }
75   LLVM_DEBUG(
76       llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a loop "
77                       "induction variable (not a block argument)\n";
78       op->dump(); coordinate.getDefiningOp()->dump());
79   return false;
80 }
81 
82 namespace {
83 struct AffineLoopAnalysis {
84   AffineLoopAnalysis() = default;
85 
86   explicit AffineLoopAnalysis(fir::DoLoopOp op, AffineFunctionAnalysis &afa)
87       : legality(analyzeLoop(op, afa)) {}
88 
89   bool canPromoteToAffine() { return legality; }
90 
91 private:
92   bool analyzeBody(fir::DoLoopOp loopOperation,
93                    AffineFunctionAnalysis &functionAnalysis) {
94     for (auto loopOp : loopOperation.getOps<fir::DoLoopOp>()) {
95       auto analysis = functionAnalysis.loopAnalysisMap
96                           .try_emplace(loopOp, loopOp, functionAnalysis)
97                           .first->getSecond();
98       if (!analysis.canPromoteToAffine())
99         return false;
100     }
101     for (auto ifOp : loopOperation.getOps<fir::IfOp>())
102       functionAnalysis.ifAnalysisMap.try_emplace(ifOp, ifOp, functionAnalysis);
103     return true;
104   }
105 
106   bool analyzeLoop(fir::DoLoopOp loopOperation,
107                    AffineFunctionAnalysis &functionAnalysis) {
108     LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n"; loopOperation.dump(););
109     return analyzeMemoryAccess(loopOperation) &&
110            analyzeBody(loopOperation, functionAnalysis);
111   }
112 
113   bool analyzeReference(mlir::Value memref, mlir::Operation *op) {
114     if (auto acoOp = memref.getDefiningOp<ArrayCoorOp>()) {
115       if (acoOp.getMemref().getType().isa<fir::BoxType>()) {
116         // TODO: Look if and how fir.box can be promoted to affine.
117         LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: cannot promote loop, "
118                                    "array memory operation uses fir.box\n";
119                    op->dump(); acoOp.dump(););
120         return false;
121       }
122       bool canPromote = true;
123       for (auto coordinate : acoOp.getIndices())
124         canPromote = canPromote && analyzeCoordinate(coordinate, op);
125       return canPromote;
126     }
127     if (auto coOp = memref.getDefiningOp<CoordinateOp>()) {
128       LLVM_DEBUG(llvm::dbgs()
129                      << "AffineLoopAnalysis: cannot promote loop, "
130                         "array memory operation uses non ArrayCoorOp\n";
131                  op->dump(); coOp.dump(););
132 
133       return false;
134     }
135     LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: unknown type of memory "
136                                "reference for array load\n";
137                op->dump(););
138     return false;
139   }
140 
141   bool analyzeMemoryAccess(fir::DoLoopOp loopOperation) {
142     for (auto loadOp : loopOperation.getOps<fir::LoadOp>())
143       if (!analyzeReference(loadOp.getMemref(), loadOp))
144         return false;
145     for (auto storeOp : loopOperation.getOps<fir::StoreOp>())
146       if (!analyzeReference(storeOp.getMemref(), storeOp))
147         return false;
148     return true;
149   }
150 
151   bool legality{};
152 };
153 } // namespace
154 
155 AffineLoopAnalysis
156 AffineFunctionAnalysis::getChildLoopAnalysis(fir::DoLoopOp op) const {
157   auto it = loopAnalysisMap.find_as(op);
158   if (it == loopAnalysisMap.end()) {
159     LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n";
160                op.dump(););
161     op.emitError("error in fetching loop analysis in AffineFunctionAnalysis\n");
162     return {};
163   }
164   return it->getSecond();
165 }
166 
167 namespace {
168 /// Calculates arguments for creating an IntegerSet. symCount, dimCount are the
169 /// final number of symbols and dimensions of the affine map. Integer set if
170 /// possible is in Optional IntegerSet.
171 struct AffineIfCondition {
172   using MaybeAffineExpr = std::optional<mlir::AffineExpr>;
173 
174   explicit AffineIfCondition(mlir::Value fc) : firCondition(fc) {
175     if (auto condDef = firCondition.getDefiningOp<mlir::arith::CmpIOp>())
176       fromCmpIOp(condDef);
177   }
178 
179   bool hasIntegerSet() const { return integerSet.has_value(); }
180 
181   mlir::IntegerSet getIntegerSet() const {
182     assert(hasIntegerSet() && "integer set is missing");
183     return *integerSet;
184   }
185 
186   mlir::ValueRange getAffineArgs() const { return affineArgs; }
187 
188 private:
189   MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, mlir::Value lhs,
190                                  mlir::Value rhs) {
191     return affineBinaryOp(kind, toAffineExpr(lhs), toAffineExpr(rhs));
192   }
193 
194   MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, MaybeAffineExpr lhs,
195                                  MaybeAffineExpr rhs) {
196     if (lhs && rhs)
197       return mlir::getAffineBinaryOpExpr(kind, *lhs, *rhs);
198     return {};
199   }
200 
201   MaybeAffineExpr toAffineExpr(MaybeAffineExpr e) { return e; }
202 
203   MaybeAffineExpr toAffineExpr(int64_t value) {
204     return {mlir::getAffineConstantExpr(value, firCondition.getContext())};
205   }
206 
207   /// Returns an AffineExpr if it is a result of operations that can be done
208   /// in an affine expression, this includes -, +, *, rem, constant.
209   /// block arguments of a loopOp or forOp are used as dimensions
210   MaybeAffineExpr toAffineExpr(mlir::Value value) {
211     if (auto op = value.getDefiningOp<mlir::arith::SubIOp>())
212       return affineBinaryOp(
213           mlir::AffineExprKind::Add, toAffineExpr(op.getLhs()),
214           affineBinaryOp(mlir::AffineExprKind::Mul, toAffineExpr(op.getRhs()),
215                          toAffineExpr(-1)));
216     if (auto op = value.getDefiningOp<mlir::arith::AddIOp>())
217       return affineBinaryOp(mlir::AffineExprKind::Add, op.getLhs(),
218                             op.getRhs());
219     if (auto op = value.getDefiningOp<mlir::arith::MulIOp>())
220       return affineBinaryOp(mlir::AffineExprKind::Mul, op.getLhs(),
221                             op.getRhs());
222     if (auto op = value.getDefiningOp<mlir::arith::RemUIOp>())
223       return affineBinaryOp(mlir::AffineExprKind::Mod, op.getLhs(),
224                             op.getRhs());
225     if (auto op = value.getDefiningOp<mlir::arith::ConstantOp>())
226       if (auto intConstant = op.getValue().dyn_cast<IntegerAttr>())
227         return toAffineExpr(intConstant.getInt());
228     if (auto blockArg = value.dyn_cast<mlir::BlockArgument>()) {
229       affineArgs.push_back(value);
230       if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()) ||
231           isa<mlir::AffineForOp>(blockArg.getOwner()->getParentOp()))
232         return {mlir::getAffineDimExpr(dimCount++, value.getContext())};
233       return {mlir::getAffineSymbolExpr(symCount++, value.getContext())};
234     }
235     return {};
236   }
237 
238   void fromCmpIOp(mlir::arith::CmpIOp cmpOp) {
239     auto lhsAffine = toAffineExpr(cmpOp.getLhs());
240     auto rhsAffine = toAffineExpr(cmpOp.getRhs());
241     if (!lhsAffine || !rhsAffine)
242       return;
243     auto constraintPair =
244         constraint(cmpOp.getPredicate(), *rhsAffine - *lhsAffine);
245     if (!constraintPair)
246       return;
247     integerSet = mlir::IntegerSet::get(
248         dimCount, symCount, {constraintPair->first}, {constraintPair->second});
249   }
250 
251   std::optional<std::pair<AffineExpr, bool>>
252   constraint(mlir::arith::CmpIPredicate predicate, mlir::AffineExpr basic) {
253     switch (predicate) {
254     case mlir::arith::CmpIPredicate::slt:
255       return {std::make_pair(basic - 1, false)};
256     case mlir::arith::CmpIPredicate::sle:
257       return {std::make_pair(basic, false)};
258     case mlir::arith::CmpIPredicate::sgt:
259       return {std::make_pair(1 - basic, false)};
260     case mlir::arith::CmpIPredicate::sge:
261       return {std::make_pair(0 - basic, false)};
262     case mlir::arith::CmpIPredicate::eq:
263       return {std::make_pair(basic, true)};
264     default:
265       return {};
266     }
267   }
268 
269   llvm::SmallVector<mlir::Value> affineArgs;
270   std::optional<mlir::IntegerSet> integerSet;
271   mlir::Value firCondition;
272   unsigned symCount{0u};
273   unsigned dimCount{0u};
274 };
275 } // namespace
276 
277 namespace {
278 /// Analysis for affine promotion of fir.if
279 struct AffineIfAnalysis {
280   AffineIfAnalysis() = default;
281 
282   explicit AffineIfAnalysis(fir::IfOp op, AffineFunctionAnalysis &afa)
283       : legality(analyzeIf(op, afa)) {}
284 
285   bool canPromoteToAffine() { return legality; }
286 
287 private:
288   bool analyzeIf(fir::IfOp op, AffineFunctionAnalysis &afa) {
289     if (op.getNumResults() == 0)
290       return true;
291     LLVM_DEBUG(llvm::dbgs()
292                    << "AffineIfAnalysis: not promoting as op has results\n";);
293     return false;
294   }
295 
296   bool legality{};
297 };
298 } // namespace
299 
300 AffineIfAnalysis
301 AffineFunctionAnalysis::getChildIfAnalysis(fir::IfOp op) const {
302   auto it = ifAnalysisMap.find_as(op);
303   if (it == ifAnalysisMap.end()) {
304     LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n";
305                op.dump(););
306     op.emitError("error in fetching if analysis in AffineFunctionAnalysis\n");
307     return {};
308   }
309   return it->getSecond();
310 }
311 
312 /// AffineMap rewriting fir.array_coor operation to affine apply,
313 /// %dim = fir.gendim %lowerBound, %upperBound, %stride
314 /// %a = fir.array_coor %arr(%dim) %i
315 /// returning affineMap = affine_map<(i)[lb, ub, st] -> (i*st - lb)>
316 static mlir::AffineMap createArrayIndexAffineMap(unsigned dimensions,
317                                                  MLIRContext *context) {
318   auto index = mlir::getAffineConstantExpr(0, context);
319   auto accuExtent = mlir::getAffineConstantExpr(1, context);
320   for (unsigned i = 0; i < dimensions; ++i) {
321     mlir::AffineExpr idx = mlir::getAffineDimExpr(i, context),
322                      lowerBound = mlir::getAffineSymbolExpr(i * 3, context),
323                      currentExtent =
324                          mlir::getAffineSymbolExpr(i * 3 + 1, context),
325                      stride = mlir::getAffineSymbolExpr(i * 3 + 2, context),
326                      currentPart = (idx * stride - lowerBound) * accuExtent;
327     index = currentPart + index;
328     accuExtent = accuExtent * currentExtent;
329   }
330   return mlir::AffineMap::get(dimensions, dimensions * 3, index);
331 }
332 
333 static std::optional<int64_t> constantIntegerLike(const mlir::Value value) {
334   if (auto definition = value.getDefiningOp<mlir::arith::ConstantOp>())
335     if (auto stepAttr = definition.getValue().dyn_cast<IntegerAttr>())
336       return stepAttr.getInt();
337   return {};
338 }
339 
340 static mlir::Type coordinateArrayElement(fir::ArrayCoorOp op) {
341   if (auto refType =
342           op.getMemref().getType().dyn_cast_or_null<ReferenceType>()) {
343     if (auto seqType = refType.getEleTy().dyn_cast_or_null<SequenceType>()) {
344       return seqType.getEleTy();
345     }
346   }
347   op.emitError(
348       "AffineLoopConversion: array type in coordinate operation not valid\n");
349   return mlir::Type();
350 }
351 
352 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeOp shape,
353                               SmallVectorImpl<mlir::Value> &indexArgs,
354                               mlir::PatternRewriter &rewriter) {
355   auto one = rewriter.create<mlir::arith::ConstantOp>(
356       acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1));
357   auto extents = shape.getExtents();
358   for (auto i = extents.begin(); i < extents.end(); i++) {
359     indexArgs.push_back(one);
360     indexArgs.push_back(*i);
361     indexArgs.push_back(one);
362   }
363 }
364 
365 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeShiftOp shape,
366                               SmallVectorImpl<mlir::Value> &indexArgs,
367                               mlir::PatternRewriter &rewriter) {
368   auto one = rewriter.create<mlir::arith::ConstantOp>(
369       acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1));
370   auto extents = shape.getPairs();
371   for (auto i = extents.begin(); i < extents.end();) {
372     indexArgs.push_back(*i++);
373     indexArgs.push_back(*i++);
374     indexArgs.push_back(one);
375   }
376 }
377 
378 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::SliceOp slice,
379                               SmallVectorImpl<mlir::Value> &indexArgs,
380                               mlir::PatternRewriter &rewriter) {
381   auto extents = slice.getTriples();
382   for (auto i = extents.begin(); i < extents.end();) {
383     indexArgs.push_back(*i++);
384     indexArgs.push_back(*i++);
385     indexArgs.push_back(*i++);
386   }
387 }
388 
389 static void populateIndexArgs(fir::ArrayCoorOp acoOp,
390                               SmallVectorImpl<mlir::Value> &indexArgs,
391                               mlir::PatternRewriter &rewriter) {
392   if (auto shape = acoOp.getShape().getDefiningOp<ShapeOp>())
393     return populateIndexArgs(acoOp, shape, indexArgs, rewriter);
394   if (auto shapeShift = acoOp.getShape().getDefiningOp<ShapeShiftOp>())
395     return populateIndexArgs(acoOp, shapeShift, indexArgs, rewriter);
396   if (auto slice = acoOp.getShape().getDefiningOp<SliceOp>())
397     return populateIndexArgs(acoOp, slice, indexArgs, rewriter);
398 }
399 
400 /// Returns affine.apply and fir.convert from array_coor and gendims
401 static std::pair<mlir::AffineApplyOp, fir::ConvertOp>
402 createAffineOps(mlir::Value arrayRef, mlir::PatternRewriter &rewriter) {
403   auto acoOp = arrayRef.getDefiningOp<ArrayCoorOp>();
404   auto affineMap =
405       createArrayIndexAffineMap(acoOp.getIndices().size(), acoOp.getContext());
406   SmallVector<mlir::Value> indexArgs;
407   indexArgs.append(acoOp.getIndices().begin(), acoOp.getIndices().end());
408 
409   populateIndexArgs(acoOp, indexArgs, rewriter);
410 
411   auto affineApply = rewriter.create<mlir::AffineApplyOp>(acoOp.getLoc(),
412                                                           affineMap, indexArgs);
413   auto arrayElementType = coordinateArrayElement(acoOp);
414   auto newType =
415       mlir::MemRefType::get({mlir::ShapedType::kDynamic}, arrayElementType);
416   auto arrayConvert = rewriter.create<fir::ConvertOp>(acoOp.getLoc(), newType,
417                                                       acoOp.getMemref());
418   return std::make_pair(affineApply, arrayConvert);
419 }
420 
421 static void rewriteLoad(fir::LoadOp loadOp, mlir::PatternRewriter &rewriter) {
422   rewriter.setInsertionPoint(loadOp);
423   auto affineOps = createAffineOps(loadOp.getMemref(), rewriter);
424   rewriter.replaceOpWithNewOp<mlir::AffineLoadOp>(
425       loadOp, affineOps.second.getResult(), affineOps.first.getResult());
426 }
427 
428 static void rewriteStore(fir::StoreOp storeOp,
429                          mlir::PatternRewriter &rewriter) {
430   rewriter.setInsertionPoint(storeOp);
431   auto affineOps = createAffineOps(storeOp.getMemref(), rewriter);
432   rewriter.replaceOpWithNewOp<mlir::AffineStoreOp>(storeOp, storeOp.getValue(),
433                                                    affineOps.second.getResult(),
434                                                    affineOps.first.getResult());
435 }
436 
437 static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) {
438   for (auto &bodyOp : block->getOperations()) {
439     if (isa<fir::LoadOp>(bodyOp))
440       rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter);
441     if (isa<fir::StoreOp>(bodyOp))
442       rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter);
443   }
444 }
445 
446 namespace {
447 /// Convert `fir.do_loop` to `affine.for`, creates fir.convert for arrays to
448 /// memref, rewrites array_coor to affine.apply with affine_map. Rewrites fir
449 /// loads and stores to affine.
450 class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
451 public:
452   using OpRewritePattern::OpRewritePattern;
453   AffineLoopConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
454       : OpRewritePattern(context), functionAnalysis(afa) {}
455 
456   mlir::LogicalResult
457   matchAndRewrite(fir::DoLoopOp loop,
458                   mlir::PatternRewriter &rewriter) const override {
459     LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n";
460                loop.dump(););
461     LLVM_ATTRIBUTE_UNUSED auto loopAnalysis =
462         functionAnalysis.getChildLoopAnalysis(loop);
463     auto &loopOps = loop.getBody()->getOperations();
464     auto loopAndIndex = createAffineFor(loop, rewriter);
465     auto affineFor = loopAndIndex.first;
466     auto inductionVar = loopAndIndex.second;
467 
468     rewriter.startRootUpdate(affineFor.getOperation());
469     affineFor.getBody()->getOperations().splice(
470         std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(),
471         std::prev(loopOps.end()));
472     rewriter.finalizeRootUpdate(affineFor.getOperation());
473 
474     rewriter.startRootUpdate(loop.getOperation());
475     loop.getInductionVar().replaceAllUsesWith(inductionVar);
476     rewriter.finalizeRootUpdate(loop.getOperation());
477 
478     rewriteMemoryOps(affineFor.getBody(), rewriter);
479 
480     LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n";
481                affineFor.dump(););
482     rewriter.replaceOp(loop, affineFor.getOperation()->getResults());
483     return success();
484   }
485 
486 private:
487   std::pair<mlir::AffineForOp, mlir::Value>
488   createAffineFor(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const {
489     if (auto constantStep = constantIntegerLike(op.getStep()))
490       if (*constantStep > 0)
491         return positiveConstantStep(op, *constantStep, rewriter);
492     return genericBounds(op, rewriter);
493   }
494 
495   // when step for the loop is positive compile time constant
496   std::pair<mlir::AffineForOp, mlir::Value>
497   positiveConstantStep(fir::DoLoopOp op, int64_t step,
498                        mlir::PatternRewriter &rewriter) const {
499     auto affineFor = rewriter.create<mlir::AffineForOp>(
500         op.getLoc(), ValueRange(op.getLowerBound()),
501         mlir::AffineMap::get(0, 1,
502                              mlir::getAffineSymbolExpr(0, op.getContext())),
503         ValueRange(op.getUpperBound()),
504         mlir::AffineMap::get(0, 1,
505                              1 + mlir::getAffineSymbolExpr(0, op.getContext())),
506         step);
507     return std::make_pair(affineFor, affineFor.getInductionVar());
508   }
509 
510   std::pair<mlir::AffineForOp, mlir::Value>
511   genericBounds(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const {
512     auto lowerBound = mlir::getAffineSymbolExpr(0, op.getContext());
513     auto upperBound = mlir::getAffineSymbolExpr(1, op.getContext());
514     auto step = mlir::getAffineSymbolExpr(2, op.getContext());
515     mlir::AffineMap upperBoundMap = mlir::AffineMap::get(
516         0, 3, (upperBound - lowerBound + step).floorDiv(step));
517     auto genericUpperBound = rewriter.create<mlir::AffineApplyOp>(
518         op.getLoc(), upperBoundMap,
519         ValueRange({op.getLowerBound(), op.getUpperBound(), op.getStep()}));
520     auto actualIndexMap = mlir::AffineMap::get(
521         1, 2,
522         (lowerBound + mlir::getAffineDimExpr(0, op.getContext())) *
523             mlir::getAffineSymbolExpr(1, op.getContext()));
524 
525     auto affineFor = rewriter.create<mlir::AffineForOp>(
526         op.getLoc(), ValueRange(),
527         AffineMap::getConstantMap(0, op.getContext()),
528         genericUpperBound.getResult(),
529         mlir::AffineMap::get(0, 1,
530                              1 + mlir::getAffineSymbolExpr(0, op.getContext())),
531         1);
532     rewriter.setInsertionPointToStart(affineFor.getBody());
533     auto actualIndex = rewriter.create<mlir::AffineApplyOp>(
534         op.getLoc(), actualIndexMap,
535         ValueRange(
536             {affineFor.getInductionVar(), op.getLowerBound(), op.getStep()}));
537     return std::make_pair(affineFor, actualIndex.getResult());
538   }
539 
540   AffineFunctionAnalysis &functionAnalysis;
541 };
542 
543 /// Convert `fir.if` to `affine.if`.
544 class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
545 public:
546   using OpRewritePattern::OpRewritePattern;
547   AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
548       : OpRewritePattern(context) {}
549   mlir::LogicalResult
550   matchAndRewrite(fir::IfOp op,
551                   mlir::PatternRewriter &rewriter) const override {
552     LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n";
553                op.dump(););
554     auto &ifOps = op.getThenRegion().front().getOperations();
555     auto affineCondition = AffineIfCondition(op.getCondition());
556     if (!affineCondition.hasIntegerSet()) {
557       LLVM_DEBUG(
558           llvm::dbgs()
559               << "AffineIfConversion: couldn't calculate affine condition\n";);
560       return failure();
561     }
562     auto affineIf = rewriter.create<mlir::AffineIfOp>(
563         op.getLoc(), affineCondition.getIntegerSet(),
564         affineCondition.getAffineArgs(), !op.getElseRegion().empty());
565     rewriter.startRootUpdate(affineIf);
566     affineIf.getThenBlock()->getOperations().splice(
567         std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(),
568         std::prev(ifOps.end()));
569     if (!op.getElseRegion().empty()) {
570       auto &otherOps = op.getElseRegion().front().getOperations();
571       affineIf.getElseBlock()->getOperations().splice(
572           std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(),
573           std::prev(otherOps.end()));
574     }
575     rewriter.finalizeRootUpdate(affineIf);
576     rewriteMemoryOps(affineIf.getBody(), rewriter);
577 
578     LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n";
579                affineIf.dump(););
580     rewriter.replaceOp(op, affineIf.getOperation()->getResults());
581     return success();
582   }
583 };
584 
585 /// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases
586 /// where such a promotion is possible.
587 class AffineDialectPromotion
588     : public fir::impl::AffineDialectPromotionBase<AffineDialectPromotion> {
589 public:
590   void runOnOperation() override {
591 
592     auto *context = &getContext();
593     auto function = getOperation();
594     markAllAnalysesPreserved();
595     auto functionAnalysis = AffineFunctionAnalysis(function);
596     mlir::RewritePatternSet patterns(context);
597     patterns.insert<AffineIfConversion>(context, functionAnalysis);
598     patterns.insert<AffineLoopConversion>(context, functionAnalysis);
599     mlir::ConversionTarget target = *context;
600     target.addLegalDialect<mlir::AffineDialect, FIROpsDialect,
601                            mlir::scf::SCFDialect, mlir::arith::ArithDialect,
602                            mlir::func::FuncDialect>();
603     target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) {
604       return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine());
605     });
606     target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis](
607                                                fir::DoLoopOp op) {
608       return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine());
609     });
610 
611     LLVM_DEBUG(llvm::dbgs()
612                    << "AffineDialectPromotion: running promotion on: \n";
613                function.print(llvm::dbgs()););
614     // apply the patterns
615     if (mlir::failed(mlir::applyPartialConversion(function, target,
616                                                   std::move(patterns)))) {
617       mlir::emitError(mlir::UnknownLoc::get(context),
618                       "error in converting to affine dialect\n");
619       signalPassFailure();
620     }
621   }
622 };
623 } // namespace
624 
625 /// Convert FIR loop constructs to the Affine dialect
626 std::unique_ptr<mlir::Pass> fir::createPromoteToAffinePass() {
627   return std::make_unique<AffineDialectPromotion>();
628 }
629