xref: /llvm-project/flang/lib/Optimizer/Transforms/LoopVersioning.cpp (revision 711419e3025678511e3d26c4c30d757f9029d598)
1a716ace1SMats Petersson //===- LoopVersioning.cpp -------------------------------------------------===//
2a716ace1SMats Petersson //
3a716ace1SMats Petersson // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a716ace1SMats Petersson // See https://llvm.org/LICENSE.txt for license information.
5a716ace1SMats Petersson // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a716ace1SMats Petersson //
7a716ace1SMats Petersson //===----------------------------------------------------------------------===//
8a716ace1SMats Petersson 
9a716ace1SMats Petersson //===----------------------------------------------------------------------===//
10a716ace1SMats Petersson /// \file
11a716ace1SMats Petersson /// This pass looks for loops iterating over assumed-shape arrays, that can
12a716ace1SMats Petersson /// be optimized by "guessing" that the stride is element-sized.
13a716ace1SMats Petersson ///
14208a4510SValentin Clement /// This is done by creating two versions of the same loop: one which assumes
15a716ace1SMats Petersson /// that the elements are contiguous (stride == size of element), and one that
16a716ace1SMats Petersson /// is the original generic loop.
17a716ace1SMats Petersson ///
18a716ace1SMats Petersson /// As a side-effect of the assumed element size stride, the array is also
19a716ace1SMats Petersson /// flattened to make it a 1D array - this is because the internal array
20a716ace1SMats Petersson /// structure must be either 1D or have known sizes in all dimensions - and at
21a716ace1SMats Petersson /// least one of the dimensions here is already unknown.
22a716ace1SMats Petersson ///
23a716ace1SMats Petersson /// There are two distinct benefits here:
24a716ace1SMats Petersson /// 1. The loop that iterates over the elements is somewhat simplified by the
25a716ace1SMats Petersson ///    constant stride calculation.
/// 2. Since the compiler knows the size of the stride, it can use vector
///    instructions, whereas an unknown (at compile time) stride often
///    prevents vector operations from being used.
29a716ace1SMats Petersson ///
/// A known drawback is that the code size is increased; in some cases that can
/// be quite substantial - 3-4x is quite plausible (this includes the effect of
/// the loop getting vectorized, which in itself often more than doubles the
/// size of the code, because unless the loop size is known, there will be a
/// modulo vector-size remainder to deal with).
35a716ace1SMats Petersson ///
/// TODO: Do we need some size limit where loops no longer get duplicated?
///       Maybe some sort of cost analysis.
/// TODO: Should some loop content - for example calls to functions and
///       subroutines - inhibit the versioning of the loops? Plausibly, this
///       could be part of the cost analysis above.
41a716ace1SMats Petersson //===----------------------------------------------------------------------===//
42a716ace1SMats Petersson 
43668f261bSSlava Zakharin #include "flang/ISO_Fortran_binding_wrapper.h"
44a716ace1SMats Petersson #include "flang/Optimizer/Builder/BoxValue.h"
45a716ace1SMats Petersson #include "flang/Optimizer/Builder/FIRBuilder.h"
46a716ace1SMats Petersson #include "flang/Optimizer/Builder/Runtime/Inquiry.h"
47a716ace1SMats Petersson #include "flang/Optimizer/Dialect/FIRDialect.h"
48a716ace1SMats Petersson #include "flang/Optimizer/Dialect/FIROps.h"
49a716ace1SMats Petersson #include "flang/Optimizer/Dialect/FIRType.h"
50a716ace1SMats Petersson #include "flang/Optimizer/Dialect/Support/FIRContext.h"
51a716ace1SMats Petersson #include "flang/Optimizer/Dialect/Support/KindMapping.h"
523e47e75fSSlava Zakharin #include "flang/Optimizer/Support/DataLayout.h"
53a716ace1SMats Petersson #include "flang/Optimizer/Transforms/Passes.h"
54a716ace1SMats Petersson #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
558dcee580SMats Petersson #include "mlir/IR/Dominance.h"
56a716ace1SMats Petersson #include "mlir/IR/Matchers.h"
57a716ace1SMats Petersson #include "mlir/IR/TypeUtilities.h"
58a716ace1SMats Petersson #include "mlir/Pass/Pass.h"
59a716ace1SMats Petersson #include "mlir/Transforms/DialectConversion.h"
60a716ace1SMats Petersson #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
61a716ace1SMats Petersson #include "mlir/Transforms/RegionUtils.h"
62a716ace1SMats Petersson #include "llvm/Support/Debug.h"
63a716ace1SMats Petersson #include "llvm/Support/raw_ostream.h"
64a716ace1SMats Petersson 
65a716ace1SMats Petersson #include <algorithm>
66a716ace1SMats Petersson 
67a716ace1SMats Petersson namespace fir {
68a716ace1SMats Petersson #define GEN_PASS_DEF_LOOPVERSIONING
69a716ace1SMats Petersson #include "flang/Optimizer/Transforms/Passes.h.inc"
70a716ace1SMats Petersson } // namespace fir
71a716ace1SMats Petersson 
72a716ace1SMats Petersson #define DEBUG_TYPE "flang-loop-versioning"
73a716ace1SMats Petersson 
74a716ace1SMats Petersson namespace {
75a716ace1SMats Petersson 
76a716ace1SMats Petersson class LoopVersioningPass
77a716ace1SMats Petersson     : public fir::impl::LoopVersioningBase<LoopVersioningPass> {
78a716ace1SMats Petersson public:
79a716ace1SMats Petersson   void runOnOperation() override;
80a716ace1SMats Petersson };
81a716ace1SMats Petersson 
827beb65aeSSlava Zakharin /// @struct ArgInfo
837beb65aeSSlava Zakharin /// A structure to hold an argument, the size of the argument and dimension
847beb65aeSSlava Zakharin /// information.
857beb65aeSSlava Zakharin struct ArgInfo {
867beb65aeSSlava Zakharin   mlir::Value arg;
877beb65aeSSlava Zakharin   size_t size;
887beb65aeSSlava Zakharin   unsigned rank;
897beb65aeSSlava Zakharin   fir::BoxDimsOp dims[CFI_MAX_RANK];
907beb65aeSSlava Zakharin };
917beb65aeSSlava Zakharin 
927beb65aeSSlava Zakharin /// @struct ArgsUsageInLoop
937beb65aeSSlava Zakharin /// A structure providing information about the function arguments
947beb65aeSSlava Zakharin /// usage by the instructions immediately nested in a loop.
957beb65aeSSlava Zakharin struct ArgsUsageInLoop {
967beb65aeSSlava Zakharin   /// Mapping between the memref operand of an array indexing
977beb65aeSSlava Zakharin   /// operation (e.g. fir.coordinate_of) and the argument information.
987beb65aeSSlava Zakharin   llvm::DenseMap<mlir::Value, ArgInfo> usageInfo;
997beb65aeSSlava Zakharin   /// Some array indexing operations inside a loop cannot be transformed.
1007beb65aeSSlava Zakharin   /// This vector holds the memref operands of such operations.
1017beb65aeSSlava Zakharin   /// The vector is used to make sure that we do not try to transform
1027beb65aeSSlava Zakharin   /// any outer loop, since this will imply the operation rewrite
1037beb65aeSSlava Zakharin   /// in this loop.
1047beb65aeSSlava Zakharin   llvm::SetVector<mlir::Value> cannotTransform;
1057beb65aeSSlava Zakharin 
1067beb65aeSSlava Zakharin   // Debug dump of the structure members assuming that
1077beb65aeSSlava Zakharin   // the information has been collected for the given loop.
1087beb65aeSSlava Zakharin   void dump(fir::DoLoopOp loop) const {
1094056287dSDavid Green     LLVM_DEBUG({
1107beb65aeSSlava Zakharin       mlir::OpPrintingFlags printFlags;
1117beb65aeSSlava Zakharin       printFlags.skipRegions();
1127beb65aeSSlava Zakharin       llvm::dbgs() << "Arguments usage info for loop:\n";
1137beb65aeSSlava Zakharin       loop.print(llvm::dbgs(), printFlags);
1147beb65aeSSlava Zakharin       llvm::dbgs() << "\nUsed args:\n";
1157beb65aeSSlava Zakharin       for (auto &use : usageInfo) {
1167beb65aeSSlava Zakharin         mlir::Value v = use.first;
1177beb65aeSSlava Zakharin         v.print(llvm::dbgs(), printFlags);
1187beb65aeSSlava Zakharin         llvm::dbgs() << "\n";
1197beb65aeSSlava Zakharin       }
1207beb65aeSSlava Zakharin       llvm::dbgs() << "\nCannot transform args:\n";
1217beb65aeSSlava Zakharin       for (mlir::Value arg : cannotTransform) {
1227beb65aeSSlava Zakharin         arg.print(llvm::dbgs(), printFlags);
1237beb65aeSSlava Zakharin         llvm::dbgs() << "\n";
1247beb65aeSSlava Zakharin       }
1254056287dSDavid Green       llvm::dbgs() << "====\n";
1264056287dSDavid Green     });
1277beb65aeSSlava Zakharin   }
1287beb65aeSSlava Zakharin 
1297beb65aeSSlava Zakharin   // Erase usageInfo and cannotTransform entries for a set
1307beb65aeSSlava Zakharin   // of given arguments.
1317beb65aeSSlava Zakharin   void eraseUsage(const llvm::SetVector<mlir::Value> &args) {
1327beb65aeSSlava Zakharin     for (auto &arg : args)
1337beb65aeSSlava Zakharin       usageInfo.erase(arg);
1347beb65aeSSlava Zakharin     cannotTransform.set_subtract(args);
1357beb65aeSSlava Zakharin   }
1367beb65aeSSlava Zakharin 
1377beb65aeSSlava Zakharin   // Erase usageInfo and cannotTransform entries for a set
1387beb65aeSSlava Zakharin   // of given arguments provided in the form of usageInfo map.
1397beb65aeSSlava Zakharin   void eraseUsage(const llvm::DenseMap<mlir::Value, ArgInfo> &args) {
1407beb65aeSSlava Zakharin     for (auto &arg : args) {
1417beb65aeSSlava Zakharin       usageInfo.erase(arg.first);
1427beb65aeSSlava Zakharin       cannotTransform.remove(arg.first);
1437beb65aeSSlava Zakharin     }
1447beb65aeSSlava Zakharin   }
1457beb65aeSSlava Zakharin };
146a716ace1SMats Petersson } // namespace
147a716ace1SMats Petersson 
148*711419e3SSlava Zakharin static fir::SequenceType getAsSequenceType(mlir::Value v) {
149*711419e3SSlava Zakharin   mlir::Type argTy = fir::unwrapPassByRefType(fir::unwrapRefType(v.getType()));
150fac349a1SChristian Sigg   return mlir::dyn_cast<fir::SequenceType>(argTy);
151a716ace1SMats Petersson }
152a716ace1SMats Petersson 
153*711419e3SSlava Zakharin /// Return the rank and the element size (in bytes) of the given
154*711419e3SSlava Zakharin /// value \p v. If it is not an array or the element type is not
155*711419e3SSlava Zakharin /// supported, then return <0, 0>. Only trivial data types
156*711419e3SSlava Zakharin /// are currently supported.
157*711419e3SSlava Zakharin /// When \p isArgument is true, \p v is assumed to be a function
158*711419e3SSlava Zakharin /// argument. If \p v's type does not look like a type of an assumed
159*711419e3SSlava Zakharin /// shape array, then the function returns <0, 0>.
160*711419e3SSlava Zakharin /// When \p isArgument is false, array types with known innermost
161*711419e3SSlava Zakharin /// dimension are allowed to proceed.
162*711419e3SSlava Zakharin static std::pair<unsigned, size_t>
163*711419e3SSlava Zakharin getRankAndElementSize(const fir::KindMapping &kindMap,
164*711419e3SSlava Zakharin                       const mlir::DataLayout &dl, mlir::Value v,
165*711419e3SSlava Zakharin                       bool isArgument = false) {
166*711419e3SSlava Zakharin   if (auto seqTy = getAsSequenceType(v)) {
167*711419e3SSlava Zakharin     unsigned rank = seqTy.getDimension();
168*711419e3SSlava Zakharin     if (rank > 0 &&
169*711419e3SSlava Zakharin         (!isArgument ||
170*711419e3SSlava Zakharin          seqTy.getShape()[0] == fir::SequenceType::getUnknownExtent())) {
171*711419e3SSlava Zakharin       size_t typeSize = 0;
172*711419e3SSlava Zakharin       mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType(v.getType());
173*711419e3SSlava Zakharin       if (fir::isa_trivial(elementType)) {
174*711419e3SSlava Zakharin         auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignmentOrCrash(
175*711419e3SSlava Zakharin             v.getLoc(), elementType, dl, kindMap);
176*711419e3SSlava Zakharin         typeSize = llvm::alignTo(eleSize, eleAlign);
177*711419e3SSlava Zakharin       }
178*711419e3SSlava Zakharin       if (typeSize)
179*711419e3SSlava Zakharin         return {rank, typeSize};
180*711419e3SSlava Zakharin     }
181*711419e3SSlava Zakharin   }
182*711419e3SSlava Zakharin 
183*711419e3SSlava Zakharin   LLVM_DEBUG(llvm::dbgs() << "Unsupported rank/type: " << v << '\n');
184*711419e3SSlava Zakharin   return {0, 0};
185*711419e3SSlava Zakharin }
186*711419e3SSlava Zakharin 
18705011024STom Eccles /// if a value comes from a fir.declare, follow it to the original source,
18805011024STom Eccles /// otherwise return the value
18905011024STom Eccles static mlir::Value unwrapFirDeclare(mlir::Value val) {
19005011024STom Eccles   // fir.declare is for source code variables. We don't have declares of
19105011024STom Eccles   // declares
19205011024STom Eccles   if (fir::DeclareOp declare = val.getDefiningOp<fir::DeclareOp>())
19305011024STom Eccles     return declare.getMemref();
19405011024STom Eccles   return val;
19505011024STom Eccles }
19605011024STom Eccles 
197*711419e3SSlava Zakharin /// Return true, if \p rebox operation keeps the input array
198*711419e3SSlava Zakharin /// continuous in the innermost dimension, if it is initially continuous
199*711419e3SSlava Zakharin /// in the innermost dimension.
200*711419e3SSlava Zakharin static bool reboxPreservesContinuity(fir::ReboxOp rebox) {
201*711419e3SSlava Zakharin   // If slicing is not involved, then the rebox does not affect
202*711419e3SSlava Zakharin   // the continuity of the array.
203*711419e3SSlava Zakharin   auto sliceArg = rebox.getSlice();
204*711419e3SSlava Zakharin   if (!sliceArg)
205*711419e3SSlava Zakharin     return true;
206*711419e3SSlava Zakharin 
207*711419e3SSlava Zakharin   // A slice with step=1 in the innermost dimension preserves
208*711419e3SSlava Zakharin   // the continuity of the array in the innermost dimension.
209*711419e3SSlava Zakharin   if (auto sliceOp =
210*711419e3SSlava Zakharin           mlir::dyn_cast_or_null<fir::SliceOp>(sliceArg.getDefiningOp())) {
211*711419e3SSlava Zakharin     if (sliceOp.getFields().empty() && sliceOp.getSubstr().empty()) {
212*711419e3SSlava Zakharin       auto triples = sliceOp.getTriples();
213*711419e3SSlava Zakharin       if (triples.size() > 2)
214*711419e3SSlava Zakharin         if (auto innermostStep = fir::getIntIfConstant(triples[2]))
215*711419e3SSlava Zakharin           if (*innermostStep == 1)
216*711419e3SSlava Zakharin             return true;
217*711419e3SSlava Zakharin     }
218*711419e3SSlava Zakharin 
219*711419e3SSlava Zakharin     LLVM_DEBUG(llvm::dbgs()
220*711419e3SSlava Zakharin                << "REBOX with slicing may produce non-contiguous array: "
221*711419e3SSlava Zakharin                << sliceOp << '\n'
222*711419e3SSlava Zakharin                << rebox << '\n');
223*711419e3SSlava Zakharin     return false;
224*711419e3SSlava Zakharin   }
225*711419e3SSlava Zakharin 
226*711419e3SSlava Zakharin   LLVM_DEBUG(llvm::dbgs() << "REBOX with unknown slice" << sliceArg << '\n'
227*711419e3SSlava Zakharin                           << rebox << '\n');
228*711419e3SSlava Zakharin   return false;
229*711419e3SSlava Zakharin }
230*711419e3SSlava Zakharin 
2318d24b732STom Eccles /// if a value comes from a fir.rebox, follow the rebox to the original source,
2328d24b732STom Eccles /// of the value, otherwise return the value
2338d24b732STom Eccles static mlir::Value unwrapReboxOp(mlir::Value val) {
234*711419e3SSlava Zakharin   while (fir::ReboxOp rebox = val.getDefiningOp<fir::ReboxOp>()) {
235*711419e3SSlava Zakharin     if (!reboxPreservesContinuity(rebox))
236*711419e3SSlava Zakharin       break;
2378d24b732STom Eccles     val = rebox.getBox();
238*711419e3SSlava Zakharin   }
2398d24b732STom Eccles   return val;
2408d24b732STom Eccles }
2418d24b732STom Eccles 
2428d24b732STom Eccles /// normalize a value (removing fir.declare and fir.rebox) so that we can
2438d24b732STom Eccles /// more conveniently spot values which came from function arguments
2448d24b732STom Eccles static mlir::Value normaliseVal(mlir::Value val) {
2458d24b732STom Eccles   return unwrapFirDeclare(unwrapReboxOp(val));
2468d24b732STom Eccles }
2478d24b732STom Eccles 
248ad9af7deSTom Eccles /// some FIR operations accept a fir.shape, a fir.shift or a fir.shapeshift.
249ad9af7deSTom Eccles /// fir.shift and fir.shapeshift allow us to extract lower bounds
250ad9af7deSTom Eccles /// if lowerbounds cannot be found, return nullptr
251ad9af7deSTom Eccles static mlir::Value tryGetLowerBoundsFromShapeLike(mlir::Value shapeLike,
252ad9af7deSTom Eccles                                                   unsigned dim) {
253ad9af7deSTom Eccles   mlir::Value lowerBound{nullptr};
254ad9af7deSTom Eccles   if (auto shift = shapeLike.getDefiningOp<fir::ShiftOp>())
255ad9af7deSTom Eccles     lowerBound = shift.getOrigins()[dim];
256ad9af7deSTom Eccles   if (auto shapeShift = shapeLike.getDefiningOp<fir::ShapeShiftOp>())
257ad9af7deSTom Eccles     lowerBound = shapeShift.getOrigins()[dim];
258ad9af7deSTom Eccles   return lowerBound;
259ad9af7deSTom Eccles }
260ad9af7deSTom Eccles 
261ad9af7deSTom Eccles /// attempt to get the array lower bounds of dimension dim of the memref
262ad9af7deSTom Eccles /// argument to a fir.array_coor op
263ad9af7deSTom Eccles /// 0 <= dim < rank
264ad9af7deSTom Eccles /// May return nullptr if no lower bounds can be determined
265ad9af7deSTom Eccles static mlir::Value getLowerBound(fir::ArrayCoorOp coop, unsigned dim) {
266ad9af7deSTom Eccles   // 1) try to get from the shape argument to fir.array_coor
267ad9af7deSTom Eccles   if (mlir::Value shapeLike = coop.getShape())
268ad9af7deSTom Eccles     if (mlir::Value lb = tryGetLowerBoundsFromShapeLike(shapeLike, dim))
269ad9af7deSTom Eccles       return lb;
270ad9af7deSTom Eccles 
271ad9af7deSTom Eccles   // It is important not to try to read the lower bound from the box, because
272ad9af7deSTom Eccles   // in the FIR lowering, boxes will sometimes contain incorrect lower bound
273ad9af7deSTom Eccles   // information
274ad9af7deSTom Eccles 
275ad9af7deSTom Eccles   // out of ideas
276ad9af7deSTom Eccles   return {};
277ad9af7deSTom Eccles }
278ad9af7deSTom Eccles 
279ad9af7deSTom Eccles /// gets the i'th index from array coordinate operation op
280ad9af7deSTom Eccles /// dim should range between 0 and rank - 1
281ad9af7deSTom Eccles static mlir::Value getIndex(fir::FirOpBuilder &builder, mlir::Operation *op,
282ad9af7deSTom Eccles                             unsigned dim) {
283ad9af7deSTom Eccles   if (fir::CoordinateOp coop = mlir::dyn_cast<fir::CoordinateOp>(op))
284ad9af7deSTom Eccles     return coop.getCoor()[dim];
285ad9af7deSTom Eccles 
286ad9af7deSTom Eccles   fir::ArrayCoorOp coop = mlir::dyn_cast<fir::ArrayCoorOp>(op);
287ad9af7deSTom Eccles   assert(coop &&
288ad9af7deSTom Eccles          "operation must be either fir.coordiante_of or fir.array_coor");
289ad9af7deSTom Eccles 
290ad9af7deSTom Eccles   // fir.coordinate_of indices start at 0: adjust these indices to match by
291ad9af7deSTom Eccles   // subtracting the lower bound
292ad9af7deSTom Eccles   mlir::Value index = coop.getIndices()[dim];
293ad9af7deSTom Eccles   mlir::Value lb = getLowerBound(coop, dim);
294ad9af7deSTom Eccles   if (!lb)
295ad9af7deSTom Eccles     // assume a default lower bound of one
296ad9af7deSTom Eccles     lb = builder.createIntegerConstant(coop.getLoc(), index.getType(), 1);
297ad9af7deSTom Eccles 
298ad9af7deSTom Eccles   // index_0 = index - lb;
299ad9af7deSTom Eccles   if (lb.getType() != index.getType())
300ad9af7deSTom Eccles     lb = builder.createConvert(coop.getLoc(), index.getType(), lb);
301ad9af7deSTom Eccles   return builder.create<mlir::arith::SubIOp>(coop.getLoc(), index, lb);
302ad9af7deSTom Eccles }
303ad9af7deSTom Eccles 
304a716ace1SMats Petersson void LoopVersioningPass::runOnOperation() {
305a716ace1SMats Petersson   LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
306a716ace1SMats Petersson   mlir::func::FuncOp func = getOperation();
307a716ace1SMats Petersson 
308a716ace1SMats Petersson   // First look for arguments with assumed shape = unknown extent in the lowest
309a716ace1SMats Petersson   // dimension.
310a716ace1SMats Petersson   LLVM_DEBUG(llvm::dbgs() << "Func-name:" << func.getSymName() << "\n");
311a716ace1SMats Petersson   mlir::Block::BlockArgListType args = func.getArguments();
312a716ace1SMats Petersson   mlir::ModuleOp module = func->getParentOfType<mlir::ModuleOp>();
313a716ace1SMats Petersson   fir::KindMapping kindMap = fir::getKindMapping(module);
314b75f9ce3SMats Petersson   mlir::SmallVector<ArgInfo, 4> argsOfInterest;
3153e47e75fSSlava Zakharin   std::optional<mlir::DataLayout> dl =
3163e47e75fSSlava Zakharin       fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
3173e47e75fSSlava Zakharin   if (!dl)
3183e47e75fSSlava Zakharin     mlir::emitError(module.getLoc(),
3193e47e75fSSlava Zakharin                     "data layout attribute is required to perform " DEBUG_TYPE
3203e47e75fSSlava Zakharin                     "pass");
321a716ace1SMats Petersson   for (auto &arg : args) {
322cccf4d6eSSlava Zakharin     // Optional arguments must be checked for IsPresent before
323cccf4d6eSSlava Zakharin     // looking for the bounds. They are unsupported for the time being.
324cccf4d6eSSlava Zakharin     if (func.getArgAttrOfType<mlir::UnitAttr>(arg.getArgNumber(),
325cccf4d6eSSlava Zakharin                                               fir::getOptionalAttrName())) {
326cccf4d6eSSlava Zakharin       LLVM_DEBUG(llvm::dbgs() << "OPTIONAL is not supported\n");
327cccf4d6eSSlava Zakharin       continue;
328cccf4d6eSSlava Zakharin     }
329cccf4d6eSSlava Zakharin 
330*711419e3SSlava Zakharin     auto [rank, typeSize] =
331*711419e3SSlava Zakharin         getRankAndElementSize(kindMap, *dl, arg, /*isArgument=*/true);
332*711419e3SSlava Zakharin     if (rank != 0 && typeSize != 0)
3338d24b732STom Eccles       argsOfInterest.push_back({arg, typeSize, rank, {}});
334a716ace1SMats Petersson   }
335a716ace1SMats Petersson 
3367beb65aeSSlava Zakharin   if (argsOfInterest.empty()) {
3377beb65aeSSlava Zakharin     LLVM_DEBUG(llvm::dbgs()
3387beb65aeSSlava Zakharin                << "No suitable arguments.\n=== End " DEBUG_TYPE " ===\n");
339ad9af7deSTom Eccles     return;
340ad9af7deSTom Eccles   }
341ad9af7deSTom Eccles 
3427beb65aeSSlava Zakharin   // A list of all loops in the function in post-order.
3437beb65aeSSlava Zakharin   mlir::SmallVector<fir::DoLoopOp> originalLoops;
3447beb65aeSSlava Zakharin   // Information about the arguments usage by the instructions
3457beb65aeSSlava Zakharin   // immediately nested in a loop.
3467beb65aeSSlava Zakharin   llvm::DenseMap<fir::DoLoopOp, ArgsUsageInLoop> argsInLoops;
3477beb65aeSSlava Zakharin 
3488dcee580SMats Petersson   auto &domInfo = getAnalysis<mlir::DominanceInfo>();
3498dcee580SMats Petersson 
3507beb65aeSSlava Zakharin   // Traverse the loops in post-order and see
3517beb65aeSSlava Zakharin   // if those arguments are used inside any loop.
3527beb65aeSSlava Zakharin   func.walk([&](fir::DoLoopOp loop) {
3537beb65aeSSlava Zakharin     mlir::Block &body = *loop.getBody();
3547beb65aeSSlava Zakharin     auto &argsInLoop = argsInLoops[loop];
3557beb65aeSSlava Zakharin     originalLoops.push_back(loop);
3567beb65aeSSlava Zakharin     body.walk([&](mlir::Operation *op) {
3577beb65aeSSlava Zakharin       // Support either fir.array_coor or fir.coordinate_of.
3587beb65aeSSlava Zakharin       if (!mlir::isa<fir::ArrayCoorOp, fir::CoordinateOp>(op))
3597beb65aeSSlava Zakharin         return;
3607beb65aeSSlava Zakharin       // Process only operations immediately nested in the current loop.
361a716ace1SMats Petersson       if (op->getParentOfType<fir::DoLoopOp>() != loop)
362a716ace1SMats Petersson         return;
36305011024STom Eccles       mlir::Value operand = op->getOperand(0);
364a716ace1SMats Petersson       for (auto a : argsOfInterest) {
3658d24b732STom Eccles         if (a.arg == normaliseVal(operand)) {
3667beb65aeSSlava Zakharin           // Use the reboxed value, not the block arg when re-creating the loop.
3678d24b732STom Eccles           a.arg = operand;
368a716ace1SMats Petersson 
3698dcee580SMats Petersson           // Check that the operand dominates the loop?
3708dcee580SMats Petersson           // If this is the case, record such operands in argsInLoop.cannot-
3718dcee580SMats Petersson           // Transform, so that they disable the transformation for the parent
3728dcee580SMats Petersson           /// loops as well.
3738dcee580SMats Petersson           if (!domInfo.dominates(a.arg, loop))
3748dcee580SMats Petersson             argsInLoop.cannotTransform.insert(a.arg);
3758dcee580SMats Petersson 
3767beb65aeSSlava Zakharin           // No support currently for sliced arrays.
3777beb65aeSSlava Zakharin           // This means that we cannot transform properly
3787beb65aeSSlava Zakharin           // instructions referencing a.arg in the whole loop
3797beb65aeSSlava Zakharin           // nest this loop is located in.
3807beb65aeSSlava Zakharin           if (auto arrayCoor = mlir::dyn_cast<fir::ArrayCoorOp>(op))
3817beb65aeSSlava Zakharin             if (arrayCoor.getSlice())
3827beb65aeSSlava Zakharin               argsInLoop.cannotTransform.insert(a.arg);
3837beb65aeSSlava Zakharin 
384*711419e3SSlava Zakharin           // We need to compute the rank and element size
385*711419e3SSlava Zakharin           // based on the operand, not the original argument,
386*711419e3SSlava Zakharin           // because array slicing may affect it.
387*711419e3SSlava Zakharin           std::tie(a.rank, a.size) = getRankAndElementSize(kindMap, *dl, a.arg);
388*711419e3SSlava Zakharin           if (a.rank == 0 || a.size == 0)
389*711419e3SSlava Zakharin             argsInLoop.cannotTransform.insert(a.arg);
390*711419e3SSlava Zakharin 
3917beb65aeSSlava Zakharin           if (argsInLoop.cannotTransform.contains(a.arg)) {
3927beb65aeSSlava Zakharin             // Remove any previously recorded usage, if any.
3937beb65aeSSlava Zakharin             argsInLoop.usageInfo.erase(a.arg);
3947beb65aeSSlava Zakharin             break;
3957beb65aeSSlava Zakharin           }
3967beb65aeSSlava Zakharin 
3977beb65aeSSlava Zakharin           // Record the a.arg usage, if not recorded yet.
3987beb65aeSSlava Zakharin           argsInLoop.usageInfo.try_emplace(a.arg, a);
399a716ace1SMats Petersson           break;
400a716ace1SMats Petersson         }
401a716ace1SMats Petersson       }
4027beb65aeSSlava Zakharin     });
403a716ace1SMats Petersson   });
404a716ace1SMats Petersson 
4057beb65aeSSlava Zakharin   // Dump loops info after initial collection.
4064056287dSDavid Green   LLVM_DEBUG({
4077beb65aeSSlava Zakharin     llvm::dbgs() << "Initial usage info:\n";
4087beb65aeSSlava Zakharin     for (fir::DoLoopOp loop : originalLoops) {
4097beb65aeSSlava Zakharin       auto &argsInLoop = argsInLoops[loop];
4107beb65aeSSlava Zakharin       argsInLoop.dump(loop);
411a716ace1SMats Petersson     }
4124056287dSDavid Green   });
4137beb65aeSSlava Zakharin 
4147beb65aeSSlava Zakharin   // Clear argument usage for parent loops if an inner loop
4157beb65aeSSlava Zakharin   // contains a non-transformable usage.
4167beb65aeSSlava Zakharin   for (fir::DoLoopOp loop : originalLoops) {
4177beb65aeSSlava Zakharin     auto &argsInLoop = argsInLoops[loop];
4187beb65aeSSlava Zakharin     if (argsInLoop.cannotTransform.empty())
4197beb65aeSSlava Zakharin       continue;
4207beb65aeSSlava Zakharin 
4217beb65aeSSlava Zakharin     fir::DoLoopOp parent = loop;
4227beb65aeSSlava Zakharin     while ((parent = parent->getParentOfType<fir::DoLoopOp>()))
4237beb65aeSSlava Zakharin       argsInLoops[parent].eraseUsage(argsInLoop.cannotTransform);
4247beb65aeSSlava Zakharin   }
4257beb65aeSSlava Zakharin 
4267beb65aeSSlava Zakharin   // If an argument access can be optimized in a loop and
4277beb65aeSSlava Zakharin   // its descendant loop, then it does not make sense to
4287beb65aeSSlava Zakharin   // generate the contiguity check for the descendant loop.
4297beb65aeSSlava Zakharin   // The check will be produced as part of the ancestor
4307beb65aeSSlava Zakharin   // loop's transformation. So we can clear the argument
4317beb65aeSSlava Zakharin   // usage for all descendant loops.
4327beb65aeSSlava Zakharin   for (fir::DoLoopOp loop : originalLoops) {
4337beb65aeSSlava Zakharin     auto &argsInLoop = argsInLoops[loop];
4347beb65aeSSlava Zakharin     if (argsInLoop.usageInfo.empty())
4357beb65aeSSlava Zakharin       continue;
4367beb65aeSSlava Zakharin 
4377beb65aeSSlava Zakharin     loop.getBody()->walk([&](fir::DoLoopOp dloop) {
4387beb65aeSSlava Zakharin       argsInLoops[dloop].eraseUsage(argsInLoop.usageInfo);
439a716ace1SMats Petersson     });
4407beb65aeSSlava Zakharin   }
4417beb65aeSSlava Zakharin 
4424056287dSDavid Green   LLVM_DEBUG({
4437beb65aeSSlava Zakharin     llvm::dbgs() << "Final usage info:\n";
4447beb65aeSSlava Zakharin     for (fir::DoLoopOp loop : originalLoops) {
4457beb65aeSSlava Zakharin       auto &argsInLoop = argsInLoops[loop];
4467beb65aeSSlava Zakharin       argsInLoop.dump(loop);
4477beb65aeSSlava Zakharin     }
4484056287dSDavid Green   });
4497beb65aeSSlava Zakharin 
4507beb65aeSSlava Zakharin   // Reduce the collected information to a list of loops
4517beb65aeSSlava Zakharin   // with attached arguments usage information.
4527beb65aeSSlava Zakharin   // The list must hold the loops in post order, so that
4537beb65aeSSlava Zakharin   // the inner loops are transformed before the outer loops.
4547beb65aeSSlava Zakharin   struct OpsWithArgs {
4557beb65aeSSlava Zakharin     mlir::Operation *op;
4567beb65aeSSlava Zakharin     mlir::SmallVector<ArgInfo, 4> argsAndDims;
4577beb65aeSSlava Zakharin   };
4587beb65aeSSlava Zakharin   mlir::SmallVector<OpsWithArgs, 4> loopsOfInterest;
4597beb65aeSSlava Zakharin   for (fir::DoLoopOp loop : originalLoops) {
4607beb65aeSSlava Zakharin     auto &argsInLoop = argsInLoops[loop];
4617beb65aeSSlava Zakharin     if (argsInLoop.usageInfo.empty())
4627beb65aeSSlava Zakharin       continue;
4637beb65aeSSlava Zakharin     OpsWithArgs info;
4647beb65aeSSlava Zakharin     info.op = loop;
4657beb65aeSSlava Zakharin     for (auto &arg : argsInLoop.usageInfo)
4667beb65aeSSlava Zakharin       info.argsAndDims.push_back(arg.second);
4677beb65aeSSlava Zakharin     loopsOfInterest.emplace_back(std::move(info));
4687beb65aeSSlava Zakharin   }
4697beb65aeSSlava Zakharin 
4707beb65aeSSlava Zakharin   if (loopsOfInterest.empty()) {
4717beb65aeSSlava Zakharin     LLVM_DEBUG(llvm::dbgs()
4727beb65aeSSlava Zakharin                << "No loops to transform.\n=== End " DEBUG_TYPE " ===\n");
473a716ace1SMats Petersson     return;
4747beb65aeSSlava Zakharin   }
475a716ace1SMats Petersson 
476a716ace1SMats Petersson   // If we get here, there are loops to process.
47753cc33b0STom Eccles   fir::FirOpBuilder builder{module, std::move(kindMap)};
478a716ace1SMats Petersson   mlir::Location loc = builder.getUnknownLoc();
479a716ace1SMats Petersson   mlir::IndexType idxTy = builder.getIndexType();
480a716ace1SMats Petersson 
481*711419e3SSlava Zakharin   LLVM_DEBUG(llvm::dbgs() << "Func Before transformation:\n");
482*711419e3SSlava Zakharin   LLVM_DEBUG(func->dump());
483a716ace1SMats Petersson 
484a716ace1SMats Petersson   LLVM_DEBUG(llvm::dbgs() << "loopsOfInterest: " << loopsOfInterest.size()
485a716ace1SMats Petersson                           << "\n");
486a716ace1SMats Petersson   for (auto op : loopsOfInterest) {
487a716ace1SMats Petersson     LLVM_DEBUG(op.op->dump());
488a716ace1SMats Petersson     builder.setInsertionPoint(op.op);
489a716ace1SMats Petersson 
490a716ace1SMats Petersson     mlir::Value allCompares = nullptr;
491a716ace1SMats Petersson     // Ensure all of the arrays are unit-stride.
492a716ace1SMats Petersson     for (auto &arg : op.argsAndDims) {
493b75f9ce3SMats Petersson       // Fetch all the dimensions of the array, except the last dimension.
494b75f9ce3SMats Petersson       // Always fetch the first dimension, however, so set ndims = 1 if
495b75f9ce3SMats Petersson       // we have one dim
496b75f9ce3SMats Petersson       unsigned ndims = arg.rank;
497b75f9ce3SMats Petersson       for (unsigned i = 0; i < ndims; i++) {
498a716ace1SMats Petersson         mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
499a716ace1SMats Petersson         arg.dims[i] = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy,
5008d24b732STom Eccles                                                      arg.arg, dimIdx);
501a716ace1SMats Petersson       }
502b75f9ce3SMats Petersson       // We only care about lowest order dimension, here.
503a716ace1SMats Petersson       mlir::Value elemSize =
504a716ace1SMats Petersson           builder.createIntegerConstant(loc, idxTy, arg.size);
505a716ace1SMats Petersson       mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
506a716ace1SMats Petersson           loc, mlir::arith::CmpIPredicate::eq, arg.dims[0].getResult(2),
507a716ace1SMats Petersson           elemSize);
508a716ace1SMats Petersson       if (!allCompares) {
509a716ace1SMats Petersson         allCompares = cmp;
510a716ace1SMats Petersson       } else {
511a716ace1SMats Petersson         allCompares =
512a716ace1SMats Petersson             builder.create<mlir::arith::AndIOp>(loc, cmp, allCompares);
513a716ace1SMats Petersson       }
514a716ace1SMats Petersson     }
515a716ace1SMats Petersson 
516a716ace1SMats Petersson     auto ifOp =
517a716ace1SMats Petersson         builder.create<fir::IfOp>(loc, op.op->getResultTypes(), allCompares,
518a716ace1SMats Petersson                                   /*withElse=*/true);
519a716ace1SMats Petersson     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
520a716ace1SMats Petersson 
521a716ace1SMats Petersson     LLVM_DEBUG(llvm::dbgs() << "Creating cloned loop\n");
522a716ace1SMats Petersson     mlir::Operation *clonedLoop = op.op->clone();
523a716ace1SMats Petersson     bool changed = false;
524a716ace1SMats Petersson     for (auto &arg : op.argsAndDims) {
525a716ace1SMats Petersson       fir::SequenceType::Shape newShape;
526a716ace1SMats Petersson       newShape.push_back(fir::SequenceType::getUnknownExtent());
5278d24b732STom Eccles       auto elementType = fir::unwrapSeqOrBoxedSeqType(arg.arg.getType());
528a716ace1SMats Petersson       mlir::Type arrTy = fir::SequenceType::get(newShape, elementType);
529a716ace1SMats Petersson       mlir::Type boxArrTy = fir::BoxType::get(arrTy);
530a716ace1SMats Petersson       mlir::Type refArrTy = builder.getRefType(arrTy);
5318d24b732STom Eccles       auto carg = builder.create<fir::ConvertOp>(loc, boxArrTy, arg.arg);
532a716ace1SMats Petersson       auto caddr = builder.create<fir::BoxAddrOp>(loc, refArrTy, carg);
533a716ace1SMats Petersson       auto insPt = builder.saveInsertionPoint();
534a716ace1SMats Petersson       // Use caddr instead of arg.
535ad9af7deSTom Eccles       clonedLoop->walk([&](mlir::Operation *coop) {
536ad9af7deSTom Eccles         if (!mlir::isa<fir::CoordinateOp, fir::ArrayCoorOp>(coop))
537ad9af7deSTom Eccles           return;
538a716ace1SMats Petersson         // Reduce the multi-dimensioned index to a single index.
539a716ace1SMats Petersson         // This is required becase fir arrays do not support multiple dimensions
540a716ace1SMats Petersson         // with unknown dimensions at compile time.
541b75f9ce3SMats Petersson         // We then calculate the multidimensional array like this:
542b75f9ce3SMats Petersson         // arr(x, y, z) bedcomes arr(z * stride(2) + y * stride(1) + x)
543b75f9ce3SMats Petersson         // where stride is the distance between elements in the dimensions
544b75f9ce3SMats Petersson         // 0, 1 and 2 or x, y and z.
5458d24b732STom Eccles         if (coop->getOperand(0) == arg.arg && coop->getOperands().size() >= 2) {
546a716ace1SMats Petersson           builder.setInsertionPoint(coop);
547b75f9ce3SMats Petersson           mlir::Value totalIndex;
548b75f9ce3SMats Petersson           for (unsigned i = arg.rank - 1; i > 0; i--) {
549a716ace1SMats Petersson             mlir::Value curIndex =
550ad9af7deSTom Eccles                 builder.createConvert(loc, idxTy, getIndex(builder, coop, i));
551b75f9ce3SMats Petersson             // Multiply by the stride of this array. Later we'll divide by the
552b75f9ce3SMats Petersson             // element size.
553a716ace1SMats Petersson             mlir::Value scale =
554b75f9ce3SMats Petersson                 builder.createConvert(loc, idxTy, arg.dims[i].getResult(2));
555b75f9ce3SMats Petersson             curIndex =
556b75f9ce3SMats Petersson                 builder.create<mlir::arith::MulIOp>(loc, scale, curIndex);
557b75f9ce3SMats Petersson             totalIndex = (totalIndex) ? builder.create<mlir::arith::AddIOp>(
558b75f9ce3SMats Petersson                                             loc, curIndex, totalIndex)
559b75f9ce3SMats Petersson                                       : curIndex;
560a716ace1SMats Petersson           }
561b75f9ce3SMats Petersson           // This is the lowest dimension - which doesn't need scaling
562b75f9ce3SMats Petersson           mlir::Value finalIndex =
563ad9af7deSTom Eccles               builder.createConvert(loc, idxTy, getIndex(builder, coop, 0));
564b75f9ce3SMats Petersson           if (totalIndex) {
565b812932bSMats Petersson             assert(llvm::isPowerOf2_32(arg.size) &&
566b812932bSMats Petersson                    "Expected power of two here");
567b812932bSMats Petersson             unsigned bits = llvm::Log2_32(arg.size);
568b812932bSMats Petersson             mlir::Value elemShift =
569b812932bSMats Petersson                 builder.createIntegerConstant(loc, idxTy, bits);
570a716ace1SMats Petersson             totalIndex = builder.create<mlir::arith::AddIOp>(
571b75f9ce3SMats Petersson                 loc,
572b812932bSMats Petersson                 builder.create<mlir::arith::ShRSIOp>(loc, totalIndex,
573b812932bSMats Petersson                                                      elemShift),
574b75f9ce3SMats Petersson                 finalIndex);
575b75f9ce3SMats Petersson           } else {
576b75f9ce3SMats Petersson             totalIndex = finalIndex;
577b75f9ce3SMats Petersson           }
578a716ace1SMats Petersson           auto newOp = builder.create<fir::CoordinateOp>(
579a716ace1SMats Petersson               loc, builder.getRefType(elementType), caddr,
580a716ace1SMats Petersson               mlir::ValueRange{totalIndex});
581a716ace1SMats Petersson           LLVM_DEBUG(newOp->dump());
582a716ace1SMats Petersson           coop->getResult(0).replaceAllUsesWith(newOp->getResult(0));
583a716ace1SMats Petersson           coop->erase();
584a716ace1SMats Petersson           changed = true;
585a716ace1SMats Petersson         }
586a716ace1SMats Petersson       });
587a716ace1SMats Petersson 
588a716ace1SMats Petersson       builder.restoreInsertionPoint(insPt);
589a716ace1SMats Petersson     }
590a716ace1SMats Petersson     assert(changed && "Expected operations to have changed");
591a716ace1SMats Petersson 
592a716ace1SMats Petersson     builder.insert(clonedLoop);
593a716ace1SMats Petersson     // Forward the result(s), if any, from the loop operation to the
594a716ace1SMats Petersson     //
595a716ace1SMats Petersson     mlir::ResultRange results = clonedLoop->getResults();
596a716ace1SMats Petersson     bool hasResults = (results.size() > 0);
597a716ace1SMats Petersson     if (hasResults)
598a716ace1SMats Petersson       builder.create<fir::ResultOp>(loc, results);
599a716ace1SMats Petersson 
600a716ace1SMats Petersson     // Add the original loop in the else-side of the if operation.
601a716ace1SMats Petersson     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
60249212d16SDavid Green     op.op->replaceAllUsesWith(ifOp);
603a716ace1SMats Petersson     op.op->remove();
604a716ace1SMats Petersson     builder.insert(op.op);
605a716ace1SMats Petersson     // Rely on "cloned loop has results, so original loop also has results".
606a716ace1SMats Petersson     if (hasResults) {
607a716ace1SMats Petersson       builder.create<fir::ResultOp>(loc, op.op->getResults());
608a716ace1SMats Petersson     } else {
609a716ace1SMats Petersson       // Use an assert to check this.
610a716ace1SMats Petersson       assert(op.op->getResults().size() == 0 &&
611a716ace1SMats Petersson              "Weird, the cloned loop doesn't have results, but the original "
612a716ace1SMats Petersson              "does?");
613a716ace1SMats Petersson     }
614a716ace1SMats Petersson   }
615a716ace1SMats Petersson 
616*711419e3SSlava Zakharin   LLVM_DEBUG(llvm::dbgs() << "Func After transform:\n");
617*711419e3SSlava Zakharin   LLVM_DEBUG(func->dump());
618a716ace1SMats Petersson 
619a716ace1SMats Petersson   LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
620a716ace1SMats Petersson }
621