xref: /llvm-project/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp (revision 906784a399069f0829b28273c027fbe36be40ce9)
1 //===-- ArrayValueCopy.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Builder/Array.h"
11 #include "flang/Optimizer/Builder/BoxValue.h"
12 #include "flang/Optimizer/Builder/FIRBuilder.h"
13 #include "flang/Optimizer/Builder/Factory.h"
14 #include "flang/Optimizer/Builder/Runtime/Derived.h"
15 #include "flang/Optimizer/Builder/Todo.h"
16 #include "flang/Optimizer/Dialect/FIRDialect.h"
17 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
18 #include "flang/Optimizer/Support/FIRContext.h"
19 #include "flang/Optimizer/Transforms/Passes.h"
20 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
21 #include "mlir/Dialect/SCF/IR/SCF.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/Support/Debug.h"
24 
25 #define DEBUG_TYPE "flang-array-value-copy"
26 
27 using namespace fir;
28 using namespace mlir;
29 
30 using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>;
31 
32 namespace {
33 
/// Array copy analysis.
/// Perform an interference analysis between array values.
///
/// Lowering will generate a sequence of the following form.
/// ```mlir
///   %a_1 = fir.array_load %array_1(%shape) : ...
///   ...
///   %a_j = fir.array_load %array_j(%shape) : ...
///   ...
///   %a_n = fir.array_load %array_n(%shape) : ...
///     ...
///     %v_i = fir.array_fetch %a_i, ...
///     %a_j1 = fir.array_update %a_j, ...
///     ...
///   fir.array_merge_store %a_j, %a_jn to %array_j : ...
/// ```
///
/// The analysis is to determine if there are any conflicts. A conflict is when
/// one of the following cases occurs.
///
/// 1. There is an `array_update` to an array value, a_j, such that a_j was
/// loaded from the same array memory reference (array_j) but with a different
/// shape than the other array values a_i, where i != j. [Possible overlapping
/// arrays.]
///
/// 2. There is either an array_fetch or array_update of a_j with a different
/// set of index values. [Possible loop-carried dependence.]
///
/// If none of the array values overlap in storage and the accesses are not
/// loop-carried, then the arrays are conflict-free and no copies are required.
class ArrayCopyAnalysis {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ArrayCopyAnalysis)

  // Operations (array loads and merge stores) found to be in conflict.
  using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>;
  // Transitive set of OpOperands that use a given array load's value.
  using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>;
  using LoadMapSetsT = llvm::DenseMap<mlir::Operation *, UseSetT>;
  using AmendAccessSetT = llvm::SmallPtrSet<mlir::Operation *, 4>;

  // The analysis runs eagerly at construction; results are cached in the
  // members below and queried through the const accessors.
  ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); }

  mlir::Operation *getOperation() const { return operation; }

  /// Return true iff the `array_merge_store` has potential conflicts.
  bool hasPotentialConflict(mlir::Operation *op) const {
    LLVM_DEBUG(llvm::dbgs()
               << "looking for a conflict on " << *op
               << " and the set has a total of " << conflicts.size() << '\n');
    return conflicts.contains(op);
  }

  /// Return the use map.
  /// The use map maps array access, amend, fetch and update operations back to
  /// the array load that is the original source of the array value.
  /// It maps an array_load to an array_merge_store, if and only if the loaded
  /// array value has pending modifications to be merged.
  const OperationUseMapT &getUseMap() const { return useMap; }

  /// Return the set of array_access ops directly associated with array_amend
  /// ops.
  bool inAmendAccessSet(mlir::Operation *op) const {
    return amendAccesses.count(op);
  }

  /// For ArrayLoad `load`, return the transitive set of all OpOperands.
  UseSetT getLoadUseSet(mlir::Operation *load) const {
    assert(loadMapSets.count(load) && "analysis missed an array load?");
    return loadMapSets.lookup(load);
  }

  /// Collect (into `mentions`) every array operation that accesses the array
  /// value produced by `load`. Results are cached per load in `loadMapSets`.
  void arrayMentions(llvm::SmallVectorImpl<mlir::Operation *> &mentions,
                     ArrayLoadOp load);

private:
  // Walk `topLevelOp`, populating the conflict set, use map, and amend-access
  // set for every array_merge_store / array_load encountered.
  void construct(mlir::Operation *topLevelOp);

  mlir::Operation *operation; // operation that analysis ran upon
  ConflictSetT conflicts;     // set of conflicts (loads and merge stores)
  OperationUseMapT useMap;
  LoadMapSetsT loadMapSets; // cache for arrayMentions/getLoadUseSet
  // Set of array_access ops associated with array_amend ops.
  AmendAccessSetT amendAccesses;
};
117 } // namespace
118 
119 namespace {
/// Helper class to collect all array operations that produced an array value.
class ReachCollector {
public:
  ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach,
                 mlir::Region *loopRegion)
      : reach{reach}, loopRegion{loopRegion} {}

  /// Process each value in `range`. An empty range is forwarded as a null
  /// value so the op `op` itself is still examined by the op overload.
  void collectArrayMentionFrom(mlir::Operation *op, mlir::ValueRange range) {
    if (range.empty()) {
      collectArrayMentionFrom(op, mlir::Value{});
      return;
    }
    for (mlir::Value v : range)
      collectArrayMentionFrom(v);
  }

  // Collect all the array_access ops in `block`. This recursively looks into
  // blocks in ops with regions.
  // FIXME: This is temporarily relying on the array_amend appearing in a
  // do_loop Region.  This phase ordering assumption can be eliminated by using
  // dominance information to find the array_access ops or by scanning the
  // transitive closure of the amending array_access's users and the defs that
  // reach them.
  void collectAccesses(llvm::SmallVector<ArrayAccessOp> &result,
                       mlir::Block *block) {
    for (auto &op : *block) {
      if (auto access = mlir::dyn_cast<ArrayAccessOp>(op)) {
        LLVM_DEBUG(llvm::dbgs() << "adding access: " << access << '\n');
        result.push_back(access);
        continue;
      }
      // Recurse into any nested regions (fir.do_loop, fir.if, ...).
      for (auto &region : op.getRegions())
        for (auto &bb : region.getBlocks())
          collectAccesses(result, &bb);
    }
  }

  void collectArrayMentionFrom(mlir::Operation *op, mlir::Value val) {
    // `val` is defined by an Op, process the defining Op.
    // If `val` is defined by a region containing Op, we want to drill down
    // and through that Op's region(s).
    LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n');
    // For a region-holding op, map result number of `val` back to the op(s)
    // inside the region(s) that produce it, and recurse on those.
    auto popFn = [&](auto rop) {
      assert(val && "op must have a result value");
      auto resNum = val.cast<mlir::OpResult>().getResultNumber();
      llvm::SmallVector<mlir::Value> results;
      rop.resultToSourceOps(results, resNum);
      for (auto u : results)
        collectArrayMentionFrom(u);
    };
    if (auto rop = mlir::dyn_cast<DoLoopOp>(op)) {
      popFn(rop);
      return;
    }
    if (auto rop = mlir::dyn_cast<IterWhileOp>(op)) {
      popFn(rop);
      return;
    }
    if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) {
      popFn(rop);
      return;
    }
    if (auto box = mlir::dyn_cast<EmboxOp>(op)) {
      // Chase all other users of the boxed memory reference.
      for (auto *user : box.getMemref().getUsers())
        if (user != op)
          collectArrayMentionFrom(user, user->getResults());
      return;
    }
    if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
      // Only merge stores inside the loop nest feed the value being traced.
      if (opIsInsideLoops(mergeStore))
        collectArrayMentionFrom(mergeStore.getSequence());
      return;
    }

    if (mlir::isa<AllocaOp, AllocMemOp>(op)) {
      // Look for any stores inside the loops, and collect an array operation
      // that produced the value being stored to it.
      for (auto *user : op->getUsers())
        if (auto store = mlir::dyn_cast<fir::StoreOp>(user))
          if (opIsInsideLoops(store))
            collectArrayMentionFrom(store.getValue());
      return;
    }

    // Scan the uses of amend's memref
    if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op)) {
      reach.push_back(op);
      llvm::SmallVector<ArrayAccessOp> accesses;
      collectAccesses(accesses, op->getBlock());
      for (auto access : accesses)
        collectArrayMentionFrom(access.getResult());
    }
    // Note: no early return in the amend case above — the amend op's operands
    // are still chased by the generic operand walk at the end.

    // Otherwise, Op does not contain a region so just chase its operands.
    if (mlir::isa<ArrayAccessOp, ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp,
                  ArrayFetchOp>(op)) {
      LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
      reach.push_back(op);
    }

    // Include all array_access ops using an array_load.
    if (auto arrLd = mlir::dyn_cast<ArrayLoadOp>(op))
      for (auto *user : arrLd.getResult().getUsers())
        if (mlir::isa<ArrayAccessOp>(user)) {
          LLVM_DEBUG(llvm::dbgs() << "add " << *user << " to reachable set\n");
          reach.push_back(user);
        }

    // Array modify assignment is performed on the result. So the analysis must
    // look at what is done with the result.
    if (mlir::isa<ArrayModifyOp>(op))
      for (auto *user : op->getResult(0).getUsers())
        followUsers(user);

    // Calls are recorded so that a conservative host-association conflict
    // check can be performed on them later.
    if (mlir::isa<fir::CallOp>(op)) {
      LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
      reach.push_back(op);
    }

    for (auto u : op->getOperands())
      collectArrayMentionFrom(u);
  }

  /// Trace a block argument back to the value(s) that feed it.
  void collectArrayMentionFrom(mlir::BlockArgument ba) {
    auto *parent = ba.getOwner()->getParentOp();
    // If inside an Op holding a region, the block argument corresponds to an
    // argument passed to the containing Op.
    auto popFn = [&](auto rop) {
      collectArrayMentionFrom(rop.blockArgToSourceOp(ba.getArgNumber()));
    };
    if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) {
      popFn(rop);
      return;
    }
    if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) {
      popFn(rop);
      return;
    }
    // Otherwise, a block argument is provided via the pred blocks.
    for (auto *pred : ba.getOwner()->getPredecessors()) {
      auto u = pred->getTerminator()->getOperand(ba.getArgNumber());
      collectArrayMentionFrom(u);
    }
  }

  // Recursively trace operands to find all array operations relating to the
  // values merged.
  void collectArrayMentionFrom(mlir::Value val) {
    // The `visited` set terminates the recursion on cyclic def-use chains.
    if (!val || visited.contains(val))
      return;
    visited.insert(val);

    // Process a block argument.
    if (auto ba = val.dyn_cast<mlir::BlockArgument>()) {
      collectArrayMentionFrom(ba);
      return;
    }

    // Process an Op.
    if (auto *op = val.getDefiningOp()) {
      collectArrayMentionFrom(op, val);
      return;
    }

    emitFatalError(val.getLoc(), "unhandled value");
  }

  /// Return all ops that produce the array value that is stored into the
  /// `array_merge_store`.
  static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach,
                             mlir::Value seq) {
    reach.clear();
    mlir::Region *loopRegion = nullptr;
    // Only a defining fir.do_loop establishes the loop region tracked by
    // opIsInsideLoops.
    if (auto doLoop = mlir::dyn_cast_or_null<DoLoopOp>(seq.getDefiningOp()))
      loopRegion = &doLoop->getRegion(0);
    ReachCollector collector(reach, loopRegion);
    collector.collectArrayMentionFrom(seq);
  }

private:
  /// Is \op inside the loop nest region ?
  /// FIXME: replace this structural dependence with graph properties.
  bool opIsInsideLoops(mlir::Operation *op) const {
    auto *region = op->getParentRegion();
    while (region) {
      if (region == loopRegion)
        return true;
      region = region->getParentRegion();
    }
    return false;
  }

  /// Recursively trace the use of an operation results, calling
  /// collectArrayMentionFrom on the direct and indirect user operands.
  void followUsers(mlir::Operation *op) {
    for (auto userOperand : op->getOperands())
      collectArrayMentionFrom(userOperand);
    // Go through potential converts/coordinate_op.
    for (auto indirectUser : op->getUsers())
      followUsers(indirectUser);
  }

  llvm::SmallVectorImpl<mlir::Operation *> &reach; // result accumulator
  llvm::SmallPtrSet<mlir::Value, 16> visited;      // values already traced
  /// Region of the loops nest that produced the array value.
  mlir::Region *loopRegion;
};
327 } // namespace
328 
/// Find all the array operations that access the array value that is loaded by
/// the array load operation, `load`.
void ArrayCopyAnalysis::arrayMentions(
    llvm::SmallVectorImpl<mlir::Operation *> &mentions, ArrayLoadOp load) {
  mentions.clear();
  // Fast path: this load was already analyzed; rebuild `mentions` from the
  // cached transitive use set.
  auto lmIter = loadMapSets.find(load);
  if (lmIter != loadMapSets.end()) {
    for (auto *opnd : lmIter->second) {
      auto *owner = opnd->getOwner();
      if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
                    ArrayModifyOp>(owner))
        mentions.push_back(owner);
    }
    return;
  }

  UseSetT visited;
  llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig]

  // Enqueue every not-yet-seen use of `val` for processing.
  auto appendToQueue = [&](mlir::Value val) {
    for (auto &use : val.getUses())
      if (!visited.count(&use)) {
        visited.insert(&use);
        queue.push_back(&use);
      }
  };

  // Build the set of uses of `original`.
  // let USES = { uses of original fir.load }
  appendToQueue(load);

  // Process the worklist until done.
  while (!queue.empty()) {
    mlir::OpOperand *operand = queue.pop_back_val();
    mlir::Operation *owner = operand->getOwner();
    if (!owner)
      continue;
    // Thread an iter_arg of a structured loop to both the corresponding loop
    // result and the region block argument carrying it.
    auto structuredLoop = [&](auto ro) {
      if (auto blockArg = ro.iterArgToBlockArg(operand->get())) {
        int64_t arg = blockArg.getArgNumber();
        mlir::Value output = ro.getResult(ro.getFinalValue() ? arg : arg - 1);
        appendToQueue(output);
        appendToQueue(blockArg);
      }
    };
    // TODO: this needs to be updated to use the control-flow interface.
    auto branchOp = [&](mlir::Block *dest, OperandRange operands) {
      if (operands.empty())
        return;

      // Check if this operand is within the range.
      unsigned operandIndex = operand->getOperandNumber();
      unsigned operandsStart = operands.getBeginOperandIndex();
      if (operandIndex < operandsStart ||
          operandIndex >= (operandsStart + operands.size()))
        return;

      // Index the successor.
      unsigned argIndex = operandIndex - operandsStart;
      appendToQueue(dest->getArgument(argIndex));
    };
    // Thread uses into structured loop bodies and return value uses.
    if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) {
      structuredLoop(ro);
    } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) {
      structuredLoop(ro);
    } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) {
      // Thread any uses of fir.if that return the marked array value.
      mlir::Operation *parent = rs->getParentRegion()->getParentOp();
      if (auto ifOp = mlir::dyn_cast<fir::IfOp>(parent))
        appendToQueue(ifOp.getResult(operand->getOperandNumber()));
    } else if (mlir::isa<ArrayFetchOp>(owner)) {
      // Keep track of array value fetches.
      LLVM_DEBUG(llvm::dbgs()
                 << "add fetch {" << *owner << "} to array value set\n");
      mentions.push_back(owner);
    } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) {
      // Keep track of array value updates and thread the return value uses.
      LLVM_DEBUG(llvm::dbgs()
                 << "add update {" << *owner << "} to array value set\n");
      mentions.push_back(owner);
      appendToQueue(update.getResult());
    } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) {
      // Keep track of array value modification and thread the return value
      // uses. Result 1 is the one threaded here (result 0 appears to be the
      // element reference — see the ArrayModifyOp definition).
      LLVM_DEBUG(llvm::dbgs()
                 << "add modify {" << *owner << "} to array value set\n");
      mentions.push_back(owner);
      appendToQueue(update.getResult(1));
    } else if (auto mention = mlir::dyn_cast<ArrayAccessOp>(owner)) {
      mentions.push_back(owner);
    } else if (auto amend = mlir::dyn_cast<ArrayAmendOp>(owner)) {
      mentions.push_back(owner);
      appendToQueue(amend.getResult());
    } else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) {
      branchOp(br.getDest(), br.getDestOperands());
    } else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) {
      branchOp(br.getTrueDest(), br.getTrueOperands());
      branchOp(br.getFalseDest(), br.getFalseOperands());
    } else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
      // do nothing
    } else {
      llvm::report_fatal_error("array value reached unexpected op");
    }
  }
  // Cache the transitive use set for subsequent queries on the same load.
  loadMapSets.insert({load, visited});
}
436 
437 static bool hasPointerType(mlir::Type type) {
438   if (auto boxTy = type.dyn_cast<BoxType>())
439     type = boxTy.getEleTy();
440   return type.isa<fir::PointerType>();
441 }
442 
// This is a NF performance hack. It makes a simple test that the slices of the
// load, \p ld, and the merge store, \p st, are trivially mutually exclusive.
// Returns true only when the two slices provably do not overlap; any doubt
// answers false (i.e. "possibly overlapping").
static bool mutuallyExclusiveSliceRange(ArrayLoadOp ld, ArrayMergeStoreOp st) {
  // If the same array_load, then no further testing is warranted.
  if (ld.getResult() == st.getOriginal())
    return false;

  // Fetch the defining fir.slice op of a value, if any.
  auto getSliceOp = [](mlir::Value val) -> SliceOp {
    if (!val)
      return {};
    auto sliceOp = mlir::dyn_cast_or_null<SliceOp>(val.getDefiningOp());
    if (!sliceOp)
      return {};
    return sliceOp;
  };

  auto ldSlice = getSliceOp(ld.getSlice());
  auto stSlice = getSliceOp(st.getSlice());
  if (!ldSlice || !stSlice)
    return false;

  // Resign on subobject slices (component paths or substrings): the simple
  // triple comparison below does not cover them.
  if (!ldSlice.getFields().empty() || !stSlice.getFields().empty() ||
      !ldSlice.getSubstr().empty() || !stSlice.getSubstr().empty())
    return false;

  // Crudely test that the two slices do not overlap by looking for the
  // following general condition. If the slices look like (i:j) and (j+1:k) then
  // these ranges do not overlap. The addend must be a constant.
  auto ldTriples = ldSlice.getTriples();
  auto stTriples = stSlice.getTriples();
  const auto size = ldTriples.size();
  if (size != stTriples.size())
    return false;

  // True when v2 is provably v1 plus a positive constant (or v1 is v2 minus
  // one), i.e. the two bounds are displaced by a strictly positive amount.
  auto displacedByConstant = [](mlir::Value v1, mlir::Value v2) {
    // Look through any chain of fir.convert ops to the underlying def.
    auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
      auto *op = v.getDefiningOp();
      while (auto conv = mlir::dyn_cast_or_null<ConvertOp>(op))
        op = conv.getValue().getDefiningOp();
      return op;
    };

    auto isPositiveConstant = [](mlir::Value v) -> bool {
      if (auto conOp =
              mlir::dyn_cast<mlir::arith::ConstantOp>(v.getDefiningOp()))
        if (auto iattr = conOp.getValue().dyn_cast<mlir::IntegerAttr>())
          return iattr.getInt() > 0;
      return false;
    };

    auto *op1 = removeConvert(v1);
    auto *op2 = removeConvert(v2);
    if (!op1 || !op2)
      return false;
    // Case: v2 = v1 + c, c > 0 (either operand order of the addi).
    if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
      if ((addi.getLhs().getDefiningOp() == op1 &&
           isPositiveConstant(addi.getRhs())) ||
          (addi.getRhs().getDefiningOp() == op1 &&
           isPositiveConstant(addi.getLhs())))
        return true;
    // Case: v1 = v2 - c, c > 0.
    if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
      if (subi.getLhs().getDefiningOp() == op2 &&
          isPositiveConstant(subi.getRhs()))
        return true;
    return false;
  };

  // Triples are flattened as (lb, ub, stride) per dimension, hence step 3.
  for (std::remove_const_t<decltype(size)> i = 0; i < size; i += 3) {
    // If both are loop invariant, skip to the next triple.
    // (An undef upper bound marks a scalar/invariant dimension here.)
    if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i + 1].getDefiningOp()) &&
        mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i + 1].getDefiningOp())) {
      // Unless either is a vector index, then be conservative.
      if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i].getDefiningOp()) ||
          mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i].getDefiningOp()))
        return false;
      continue;
    }
    // If identical, skip to the next triple.
    if (ldTriples[i] == stTriples[i] && ldTriples[i + 1] == stTriples[i + 1] &&
        ldTriples[i + 2] == stTriples[i + 2])
      continue;
    // If ubound and lbound are the same with a constant offset, skip to the
    // next triple.
    if (displacedByConstant(ldTriples[i + 1], stTriples[i]) ||
        displacedByConstant(stTriples[i + 1], ldTriples[i]))
      continue;
    return false;
  }
  LLVM_DEBUG(llvm::dbgs() << "detected non-overlapping slice ranges on " << ld
                          << " and " << st << ", which is not a conflict\n");
  return true;
}
536 
/// Is there a conflict between the array value that was updated and to be
/// stored to `st` and the set of arrays loaded (`reach`) and used to compute
/// the updated value?
static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach,
                           ArrayMergeStoreOp st) {
  // First load of the store's own memory seen so far (at most one allowed).
  mlir::Value load;
  mlir::Value addr = st.getMemref();
  const bool storeHasPointerType = hasPointerType(addr.getType());
  for (auto *op : reach)
    if (auto ld = mlir::dyn_cast<ArrayLoadOp>(op)) {
      mlir::Type ldTy = ld.getMemref().getType();
      if (ld.getMemref() == addr) {
        // Same memory as the store target: conflict unless the slices are
        // provably disjoint or this is exactly the store's original load.
        if (mutuallyExclusiveSliceRange(ld, st))
          continue;
        if (ld.getResult() != st.getOriginal())
          return true;
        if (load) {
          // TODO: extend this to allow checking if the first `load` and this
          // `ld` are mutually exclusive accesses but not identical.
          return true;
        }
        load = ld;
      } else if ((hasPointerType(ldTy) || storeHasPointerType)) {
        // Different memory, but a POINTER attribute on either side means the
        // two references may alias, so assume a conflict.
        // TODO: Use target attribute to restrict this case further.
        // TODO: Check if types can also allow ruling out some cases. For now,
        // the fact that equivalences is using pointer attribute to enforce
        // aliasing is preventing any attempt to do so, and in general, it may
        // be wrong to use this if any of the types is a complex or a derived
        // for which it is possible to create a pointer to a part with a
        // different type than the whole, although this deserve some more
        // investigation because existing compiler behavior seem to diverge
        // here.
        return true;
      }
    }
  return false;
}
574 
575 /// Is there an access vector conflict on the array being merged into? If the
576 /// access vectors diverge, then assume that there are potentially overlapping
577 /// loop-carried references.
578 static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> mentions) {
579   if (mentions.size() < 2)
580     return false;
581   llvm::SmallVector<mlir::Value> indices;
582   LLVM_DEBUG(llvm::dbgs() << "check merge conflict on with " << mentions.size()
583                           << " mentions on the list\n");
584   bool valSeen = false;
585   bool refSeen = false;
586   for (auto *op : mentions) {
587     llvm::SmallVector<mlir::Value> compareVector;
588     if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) {
589       valSeen = true;
590       if (indices.empty()) {
591         indices = u.getIndices();
592         continue;
593       }
594       compareVector = u.getIndices();
595     } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) {
596       valSeen = true;
597       if (indices.empty()) {
598         indices = f.getIndices();
599         continue;
600       }
601       compareVector = f.getIndices();
602     } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) {
603       valSeen = true;
604       if (indices.empty()) {
605         indices = f.getIndices();
606         continue;
607       }
608       compareVector = f.getIndices();
609     } else if (auto f = mlir::dyn_cast<ArrayAccessOp>(op)) {
610       refSeen = true;
611       if (indices.empty()) {
612         indices = f.getIndices();
613         continue;
614       }
615       compareVector = f.getIndices();
616     } else if (mlir::isa<ArrayAmendOp>(op)) {
617       refSeen = true;
618       continue;
619     } else {
620       mlir::emitError(op->getLoc(), "unexpected operation in analysis");
621     }
622     if (compareVector.size() != indices.size() ||
623         llvm::any_of(llvm::zip(compareVector, indices), [&](auto pair) {
624           return std::get<0>(pair) != std::get<1>(pair);
625         }))
626       return true;
627     LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n");
628   }
629   return valSeen && refSeen;
630 }
631 
632 /// With element-by-reference semantics, an amended array with more than once
633 /// access to the same loaded array are conservatively considered a conflict.
634 /// Note: the array copy can still be eliminated in subsequent optimizations.
635 static bool conflictOnReference(llvm::ArrayRef<mlir::Operation *> mentions) {
636   LLVM_DEBUG(llvm::dbgs() << "checking reference semantics " << mentions.size()
637                           << '\n');
638   if (mentions.size() < 3)
639     return false;
640   unsigned amendCount = 0;
641   unsigned accessCount = 0;
642   for (auto *op : mentions) {
643     if (mlir::isa<ArrayAmendOp>(op) && ++amendCount > 1) {
644       LLVM_DEBUG(llvm::dbgs() << "conflict: multiple amends of array value\n");
645       return true;
646     }
647     if (mlir::isa<ArrayAccessOp>(op) && ++accessCount > 1) {
648       LLVM_DEBUG(llvm::dbgs()
649                  << "conflict: multiple accesses of array value\n");
650       return true;
651     }
652     if (mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) {
653       LLVM_DEBUG(llvm::dbgs()
654                  << "conflict: array value has both uses by-value and uses "
655                     "by-reference. conservative assumption.\n");
656       return true;
657     }
658   }
659   return false;
660 }
661 
662 static mlir::Operation *
663 amendingAccess(llvm::ArrayRef<mlir::Operation *> mentions) {
664   for (auto *op : mentions)
665     if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op))
666       return amend.getMemref().getDefiningOp();
667   return {};
668 }
669 
670 // Are any conflicts present? The conflicts detected here are described above.
671 static bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
672                              llvm::ArrayRef<mlir::Operation *> mentions,
673                              ArrayMergeStoreOp st) {
674   return conflictOnLoad(reach, st) || conflictOnMerge(mentions);
675 }
676 
677 // Assume that any call to a function that uses host-associations will be
678 // modifying the output array.
679 static bool
680 conservativeCallConflict(llvm::ArrayRef<mlir::Operation *> reaches) {
681   return llvm::any_of(reaches, [](mlir::Operation *op) {
682     if (auto call = mlir::dyn_cast<fir::CallOp>(op))
683       if (auto callee =
684               call.getCallableForCallee().dyn_cast<mlir::SymbolRefAttr>()) {
685         auto module = op->getParentOfType<mlir::ModuleOp>();
686         return hasHostAssociationArgument(
687             module.lookupSymbol<mlir::func::FuncOp>(callee));
688       }
689     return false;
690   });
691 }
692 
/// Constructor of the array copy analysis.
/// This performs the analysis and saves the intermediate results.
void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) {
  topLevelOp->walk([&](Operation *op) {
    if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) {
      // Gather everything that feeds the stored sequence, plus all mentions
      // of the original loaded array value.
      llvm::SmallVector<mlir::Operation *> values;
      ReachCollector::reachingValues(values, st.getSequence());
      bool callConflict = conservativeCallConflict(values);
      llvm::SmallVector<mlir::Operation *> mentions;
      arrayMentions(mentions,
                    mlir::cast<ArrayLoadOp>(st.getOriginal().getDefiningOp()));
      bool conflict = conflictDetected(values, mentions, st);
      bool refConflict = conflictOnReference(mentions);
      if (callConflict || conflict || refConflict) {
        // Record both the merge store and its originating load as requiring
        // an array copy.
        LLVM_DEBUG(llvm::dbgs()
                   << "CONFLICT: copies required for " << st << '\n'
                   << "   adding conflicts on: " << *op << " and "
                   << st.getOriginal() << '\n');
        conflicts.insert(op);
        conflicts.insert(st.getOriginal().getDefiningOp());
        if (auto *access = amendingAccess(mentions))
          amendAccesses.insert(access);
      }
      // Map the load to its merge store (pending modifications to merge).
      auto *ld = st.getOriginal().getDefiningOp();
      LLVM_DEBUG(llvm::dbgs()
                 << "map: adding {" << *ld << " -> " << st << "}\n");
      useMap.insert({ld, op});
    } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) {
      // Map each access/amend/fetch/update/modify mention back to this load.
      llvm::SmallVector<mlir::Operation *> mentions;
      arrayMentions(mentions, load);
      LLVM_DEBUG(llvm::dbgs() << "process load: " << load
                              << ", mentions: " << mentions.size() << '\n');
      for (auto *acc : mentions) {
        LLVM_DEBUG(llvm::dbgs() << " mention: " << *acc << '\n');
        if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
                      ArrayModifyOp>(acc)) {
          // A mention may reach several loads; only the first mapping wins.
          if (useMap.count(acc)) {
            mlir::emitError(
                load.getLoc(),
                "The parallel semantics of multiple array_merge_stores per "
                "array_load are not supported.");
            continue;
          }
          LLVM_DEBUG(llvm::dbgs()
                     << "map: adding {" << *acc << "} -> {" << load << "}\n");
          useMap.insert({acc, op});
        }
      }
    }
  });
}
744 
745 //===----------------------------------------------------------------------===//
746 // Conversions for converting out of array value form.
747 //===----------------------------------------------------------------------===//
748 
namespace {
/// Rewrite a fir.array_load to fir.undefined. By the time this pattern runs
/// (phase 2 of the pass), all array accesses tied to the load have already
/// been rewritten to memory operations, so only a placeholder SSA value of
/// the same type is required to keep the IR well-formed.
class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(ArrayLoadOp load,
                  mlir::PatternRewriter &rewriter) const override {
    LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n");
    rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType());
    return mlir::success();
  }
};

/// Erase a fir.array_merge_store. Any copy-out it implied has already been
/// materialized by the phase-1 patterns, so the op carries no semantics at
/// this point.
class ArrayMergeStoreConversion
    : public mlir::OpRewritePattern<ArrayMergeStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(ArrayMergeStoreOp store,
                  mlir::PatternRewriter &rewriter) const override {
    LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n");
    rewriter.eraseOp(store);
    return mlir::success();
  }
};
} // namespace
777 
778 static mlir::Type getEleTy(mlir::Type ty) {
779   auto eleTy = unwrapSequenceType(unwrapPassByRefType(ty));
780   // FIXME: keep ptr/heap/ref information.
781   return ReferenceType::get(eleTy);
782 }
783 
// Extract extents from the ShapeOp/ShapeShiftOp into the result vector.
// For an assumed-size array, the last (unknown) extent is replaced with one
// computed from the slice triple when a slice is present. Returns true when
// the slice must be applied during copy-in/copy-out rather than in the
// elemental computation.
static bool getAdjustedExtents(mlir::Location loc,
                               mlir::PatternRewriter &rewriter,
                               ArrayLoadOp arrLoad,
                               llvm::SmallVectorImpl<mlir::Value> &result,
                               mlir::Value shape) {
  bool copyUsingSlice = false;
  auto *shapeOp = shape.getDefiningOp();
  if (auto s = mlir::dyn_cast_or_null<ShapeOp>(shapeOp)) {
    auto e = s.getExtents();
    result.insert(result.end(), e.begin(), e.end());
  } else if (auto s = mlir::dyn_cast_or_null<ShapeShiftOp>(shapeOp)) {
    auto e = s.getExtents();
    result.insert(result.end(), e.begin(), e.end());
  } else {
    // Any other defining op (or a block argument) is unsupported here.
    emitFatalError(loc, "not a fir.shape/fir.shape_shift op");
  }
  auto idxTy = rewriter.getIndexType();
  if (factory::isAssumedSize(result)) {
    // Use slice information to compute the extent of the column.
    auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
    mlir::Value size = one;
    if (mlir::Value sliceArg = arrLoad.getSlice()) {
      if (auto sliceOp =
              mlir::dyn_cast_or_null<SliceOp>(sliceArg.getDefiningOp())) {
        auto triples = sliceOp.getTriples();
        const std::size_t tripleSize = triples.size();
        auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
        fir::KindMapping kindMap = getKindMapping(module);
        FirOpBuilder builder(rewriter, kindMap);
        // The last triple (lb, ub, stride) describes the final dimension;
        // derive the extent of that column from it.
        size = builder.genExtentFromTriplet(loc, triples[tripleSize - 3],
                                            triples[tripleSize - 2],
                                            triples[tripleSize - 1], idxTy);
        copyUsingSlice = true;
      }
    }
    // Overwrite the unknown trailing extent with the computed size (or 1 when
    // no slice information is available).
    result[result.size() - 1] = size;
  }
  return copyUsingSlice;
}
824 
/// Place the extents of the array load, \p arrLoad, into \p result and
/// return a ShapeOp or ShapeShiftOp with the same extents. If \p arrLoad is
/// loading a `!fir.box`, code will be generated to read the extents from the
/// boxed value, and the returned shape Op will be built with the extents read
/// from the box. Otherwise, the extents will be extracted from the ShapeOp (or
/// ShapeShiftOp) argument of \p arrLoad. \p copyUsingSlice will be set to true
/// if slicing of the output array is to be done in the copy-in/copy-out rather
/// than in the elemental computation step.
static mlir::Value getOrReadExtentsAndShapeOp(
    mlir::Location loc, mlir::PatternRewriter &rewriter, ArrayLoadOp arrLoad,
    llvm::SmallVectorImpl<mlir::Value> &result, bool &copyUsingSlice) {
  assert(result.empty());
  if (arrLoad->hasAttr(fir::getOptionalAttrName()))
    fir::emitFatalError(
        loc, "shapes from array load of OPTIONAL arrays must not be used");
  if (auto boxTy = arrLoad.getMemref().getType().dyn_cast<BoxType>()) {
    auto rank =
        dyn_cast_ptrOrBoxEleTy(boxTy).cast<SequenceType>().getDimension();
    auto idxTy = rewriter.getIndexType();
    // Read one extent per dimension out of the box descriptor.
    for (decltype(rank) dim = 0; dim < rank; ++dim) {
      auto dimVal = rewriter.create<mlir::arith::ConstantIndexOp>(loc, dim);
      auto dimInfo = rewriter.create<BoxDimsOp>(loc, idxTy, idxTy, idxTy,
                                                arrLoad.getMemref(), dimVal);
      // Result 1 of fir.box_dims is the extent.
      result.emplace_back(dimInfo.getResult(1));
    }
    if (!arrLoad.getShape()) {
      auto shapeType = ShapeType::get(rewriter.getContext(), rank);
      return rewriter.create<ShapeOp>(loc, shapeType, result);
    }
    // NOTE(review): a shape operand on a boxed array is assumed to be a
    // fir.shift here (lower bounds only); getDefiningOp<ShiftOp> would yield
    // null for anything else — confirm lowering guarantees this.
    auto shiftOp = arrLoad.getShape().getDefiningOp<ShiftOp>();
    auto shapeShiftType = ShapeShiftType::get(rewriter.getContext(), rank);
    // Interleave the shift's lower bounds with the extents read from the box.
    llvm::SmallVector<mlir::Value> shapeShiftOperands;
    for (auto [lb, extent] : llvm::zip(shiftOp.getOrigins(), result)) {
      shapeShiftOperands.push_back(lb);
      shapeShiftOperands.push_back(extent);
    }
    return rewriter.create<ShapeShiftOp>(loc, shapeShiftType,
                                         shapeShiftOperands);
  }
  // Unboxed case: extents come straight from the shape operand.
  copyUsingSlice =
      getAdjustedExtents(loc, rewriter, arrLoad, result, arrLoad.getShape());
  return arrLoad.getShape();
}
868 
869 static mlir::Type toRefType(mlir::Type ty) {
870   if (fir::isa_ref_type(ty))
871     return ty;
872   return fir::ReferenceType::get(ty);
873 }
874 
875 static llvm::SmallVector<mlir::Value>
876 getTypeParamsIfRawData(mlir::Location loc, FirOpBuilder &builder,
877                        ArrayLoadOp arrLoad, mlir::Type ty) {
878   if (ty.isa<BoxType>())
879     return {};
880   return fir::factory::getTypeParams(loc, builder, arrLoad);
881 }
882 
/// Generate the address of the element of the array described by \p load at
/// position \p indices, as a fir.array_coor, possibly followed by a
/// fir.coordinate_of for any indices beyond the array's rank.
/// When \p skipOrig is true the indices are used as-is and are not
/// re-originated against \p shape.
static mlir::Value genCoorOp(mlir::PatternRewriter &rewriter,
                             mlir::Location loc, mlir::Type eleTy,
                             mlir::Type resTy, mlir::Value alloc,
                             mlir::Value shape, mlir::Value slice,
                             mlir::ValueRange indices, ArrayLoadOp load,
                             bool skipOrig = false) {
  llvm::SmallVector<mlir::Value> originated;
  if (skipOrig)
    originated.assign(indices.begin(), indices.end());
  else
    originated = factory::originateIndices(loc, rewriter, alloc.getType(),
                                           shape, indices);
  auto seqTy = dyn_cast_ptrOrBoxEleTy(alloc.getType());
  assert(seqTy && seqTy.isa<SequenceType>());
  const auto dimension = seqTy.cast<SequenceType>().getDimension();
  auto module = load->getParentOfType<mlir::ModuleOp>();
  fir::KindMapping kindMap = getKindMapping(module);
  FirOpBuilder builder(rewriter, kindMap);
  auto typeparams = getTypeParamsIfRawData(loc, builder, load, alloc.getType());
  // Address the array element with the first `dimension` indices.
  mlir::Value result = rewriter.create<ArrayCoorOp>(
      loc, eleTy, alloc, shape, slice,
      llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension),
      typeparams);
  // Any remaining indices address into the element itself (e.g. a part-ref);
  // apply them with a coordinate_of.
  if (dimension < originated.size())
    result = rewriter.create<fir::CoordinateOp>(
        loc, resTy, result,
        llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension));
  return result;
}
912 
/// Compute the length of the CHARACTER elements of the array described by
/// \p load, as a value of the builder's character-length type.
static mlir::Value getCharacterLen(mlir::Location loc, FirOpBuilder &builder,
                                   ArrayLoadOp load, CharacterType charTy) {
  auto charLenTy = builder.getCharacterLengthType();
  if (charTy.hasDynamicLen()) {
    if (load.getMemref().getType().isa<BoxType>()) {
      // The loaded array is an emboxed value. Get the CHARACTER length from
      // the box value.
      auto eleSzInBytes =
          builder.create<BoxEleSizeOp>(loc, charLenTy, load.getMemref());
      auto kindSize =
          builder.getKindMap().getCharacterBitsize(charTy.getFKind());
      auto kindByteSize =
          builder.createIntegerConstant(loc, charLenTy, kindSize / 8);
      // length = element size in bytes / bytes per character of this kind.
      return builder.create<mlir::arith::DivSIOp>(loc, eleSzInBytes,
                                                  kindByteSize);
    }
    // The loaded array is a (set of) unboxed values. If the CHARACTER's
    // length is not a constant, it must be provided as a type parameter to
    // the array_load.
    auto typeparams = load.getTypeparams();
    assert(typeparams.size() > 0 && "expected type parameters on array_load");
    return typeparams.back();
  }
  // The typical case: the length of the CHARACTER is a compile-time
  // constant that is encoded in the type information.
  return builder.createIntegerConstant(loc, charLenTy, charTy.getLen());
}
/// Generate a shallow array copy. This is used for both copy-in and copy-out.
/// A loop nest iterating over every element is emitted at the current
/// insertion point (which is restored before returning).
/// \tparam CopyIn when true and the copy goes through a slice, the slice is
/// applied on the source side; otherwise on the destination side.
template <bool CopyIn>
void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter,
                  mlir::Value dst, mlir::Value src, mlir::Value shapeOp,
                  mlir::Value sliceOp, ArrayLoadOp arrLoad) {
  auto insPt = rewriter.saveInsertionPoint();
  llvm::SmallVector<mlir::Value> indices;
  llvm::SmallVector<mlir::Value> extents;
  bool copyUsingSlice =
      getAdjustedExtents(loc, rewriter, arrLoad, extents, shapeOp);
  auto idxTy = rewriter.getIndexType();
  // Build loop nest from column to row.
  for (auto sh : llvm::reverse(extents)) {
    auto ubi = rewriter.create<ConvertOp>(loc, idxTy, sh);
    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
    auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
    // Loops are zero-based: iterate from 0 to extent-1 inclusive.
    auto ub = rewriter.create<mlir::arith::SubIOp>(loc, idxTy, ubi, one);
    auto loop = rewriter.create<DoLoopOp>(loc, zero, ub, one);
    // Subsequent ops (inner loops, then the element copy) go in this body.
    rewriter.setInsertionPointToStart(loop.getBody());
    indices.push_back(loop.getInductionVar());
  }
  // Reverse the indices so they are in column-major order.
  std::reverse(indices.begin(), indices.end());
  auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
  fir::KindMapping kindMap = getKindMapping(module);
  FirOpBuilder builder(rewriter, kindMap);
  auto fromAddr = rewriter.create<ArrayCoorOp>(
      loc, getEleTy(src.getType()), src, shapeOp,
      CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
      factory::originateIndices(loc, rewriter, src.getType(), shapeOp, indices),
      getTypeParamsIfRawData(loc, builder, arrLoad, src.getType()));
  auto toAddr = rewriter.create<ArrayCoorOp>(
      loc, getEleTy(dst.getType()), dst, shapeOp,
      !CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
      factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, indices),
      getTypeParamsIfRawData(loc, builder, arrLoad, dst.getType()));
  auto eleTy = unwrapSequenceType(unwrapPassByRefType(dst.getType()));
  // Copy from (to) object to (from) temp copy of same object.
  if (auto charTy = eleTy.dyn_cast<CharacterType>()) {
    // CHARACTER elements need the length to be copied as a char box.
    auto len = getCharacterLen(loc, builder, arrLoad, charTy);
    CharBoxValue toChar(toAddr, len);
    CharBoxValue fromChar(fromAddr, len);
    factory::genScalarAssignment(builder, loc, toChar, fromChar);
  } else {
    if (hasDynamicSize(eleTy))
      TODO(loc, "copy element of dynamic size");
    factory::genScalarAssignment(builder, loc, toAddr, fromAddr);
  }
  rewriter.restoreInsertionPoint(insPt);
}
990 
/// The array load may be either a boxed or unboxed value. If the value is
/// boxed, we read the type parameters from the boxed value.
/// Returns an empty vector when the element type has no dynamic size.
static llvm::SmallVector<mlir::Value>
genArrayLoadTypeParameters(mlir::Location loc, mlir::PatternRewriter &rewriter,
                           ArrayLoadOp load) {
  if (load.getTypeparams().empty()) {
    auto eleTy =
        unwrapSequenceType(unwrapPassByRefType(load.getMemref().getType()));
    if (hasDynamicSize(eleTy)) {
      if (auto charTy = eleTy.dyn_cast<CharacterType>()) {
        // A dynamic-length CHARACTER without explicit type parameters must be
        // boxed; read the length from the box.
        assert(load.getMemref().getType().isa<BoxType>());
        auto module = load->getParentOfType<mlir::ModuleOp>();
        fir::KindMapping kindMap = getKindMapping(module);
        FirOpBuilder builder(rewriter, kindMap);
        return {getCharacterLen(loc, builder, load, charTy)};
      }
      TODO(loc, "unhandled dynamic type parameters");
    }
    return {};
  }
  return load.getTypeparams();
}
1013 
1014 static llvm::SmallVector<mlir::Value>
1015 findNonconstantExtents(mlir::Type memrefTy,
1016                        llvm::ArrayRef<mlir::Value> extents) {
1017   llvm::SmallVector<mlir::Value> nce;
1018   auto arrTy = unwrapPassByRefType(memrefTy);
1019   auto seqTy = arrTy.cast<SequenceType>();
1020   for (auto [s, x] : llvm::zip(seqTy.getShape(), extents))
1021     if (s == SequenceType::getUnknownExtent())
1022       nce.emplace_back(x);
1023   if (extents.size() > seqTy.getShape().size())
1024     for (auto x : extents.drop_front(seqTy.getShape().size()))
1025       nce.emplace_back(x);
1026   return nce;
1027 }
1028 
/// Allocate temporary storage for an ArrayLoadOp \p load and initialize any
/// allocatable direct components of the array elements with an unallocated
/// status. Returns the temporary address as well as a callback to generate the
/// temporary clean-up once it has been used. The clean-up will take care of
/// deallocating all the element allocatable components that may have been
/// allocated while using the temporary.
static std::pair<mlir::Value,
                 std::function<void(mlir::PatternRewriter &rewriter)>>
allocateArrayTemp(mlir::Location loc, mlir::PatternRewriter &rewriter,
                  ArrayLoadOp load, llvm::ArrayRef<mlir::Value> extents,
                  mlir::Value shape) {
  mlir::Type baseType = load.getMemref().getType();
  // Only extents for dimensions of unknown static size are passed to the
  // allocation.
  llvm::SmallVector<mlir::Value> nonconstantExtents =
      findNonconstantExtents(baseType, extents);
  llvm::SmallVector<mlir::Value> typeParams =
      genArrayLoadTypeParameters(loc, rewriter, load);
  mlir::Value allocmem = rewriter.create<AllocMemOp>(
      loc, dyn_cast_ptrOrBoxEleTy(baseType), typeParams, nonconstantExtents);
  mlir::Type eleType =
      fir::unwrapSequenceType(fir::unwrapPassByRefType(baseType));
  if (fir::isRecordWithAllocatableMember(eleType)) {
    // The allocatable component descriptors need to be set to a clean
    // deallocated status before anything is done with them.
    mlir::Value box = rewriter.create<fir::EmboxOp>(
        loc, fir::BoxType::get(baseType), allocmem, shape,
        /*slice=*/mlir::Value{}, typeParams);
    auto module = load->getParentOfType<mlir::ModuleOp>();
    fir::KindMapping kindMap = getKindMapping(module);
    FirOpBuilder builder(rewriter, kindMap);
    runtime::genDerivedTypeInitialize(builder, loc, box);
    // Any allocatable component that may have been allocated must be
    // deallocated during the clean-up.
    // Capture by value: the cleanup runs later, at a different insertion
    // point, and must not reference locals by dangling reference.
    auto cleanup = [=](mlir::PatternRewriter &r) {
      fir::KindMapping kindMap = getKindMapping(module);
      FirOpBuilder builder(r, kindMap);
      runtime::genDerivedTypeDestroy(builder, loc, box);
      r.create<FreeMemOp>(loc, allocmem);
    };
    return {allocmem, cleanup};
  }
  // Simple case: the temp just needs to be freed.
  auto cleanup = [=](mlir::PatternRewriter &r) {
    r.create<FreeMemOp>(loc, allocmem);
  };
  return {allocmem, cleanup};
}
1074 
namespace {
/// Conversion of fir.array_update and fir.array_modify Ops.
/// If there is a conflict for the update, then we need to perform a
/// copy-in/copy-out to preserve the original values of the array. If there is
/// no conflict, then it is safe to eschew making any copies.
template <typename ArrayOp>
class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> {
public:
  // TODO: Implement copy/swap semantics?
  explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx,
                                     const ArrayCopyAnalysis &a,
                                     const OperationUseMapT &m)
      : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {}

  /// The array_access, \p access, is to be to a cloned copy due to a potential
  /// conflict. Uses copy-in/copy-out semantics and not copy/swap.
  /// Returns the element address computed on the temporary copy.
  mlir::Value referenceToClone(mlir::Location loc,
                               mlir::PatternRewriter &rewriter,
                               ArrayOp access) const {
    LLVM_DEBUG(llvm::dbgs()
               << "generating copy-in/copy-out loops for " << access << '\n');
    auto *op = access.getOperation();
    // useMap: access -> array_load, and array_load -> array_merge_store.
    auto *loadOp = useMap.lookup(op);
    auto load = mlir::cast<ArrayLoadOp>(loadOp);
    auto eleTy = access.getType();
    rewriter.setInsertionPoint(loadOp);
    // Copy in.
    llvm::SmallVector<mlir::Value> extents;
    bool copyUsingSlice = false;
    auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
                                              copyUsingSlice);
    auto [allocmem, genTempCleanUp] =
        allocateArrayTemp(loc, rewriter, load, extents, shapeOp);
    genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
                                  load.getMemref(), shapeOp, load.getSlice(),
                                  load);
    // Generate the reference for the access.
    rewriter.setInsertionPoint(op);
    auto coor = genCoorOp(
        rewriter, loc, getEleTy(load.getType()), eleTy, allocmem, shapeOp,
        copyUsingSlice ? mlir::Value{} : load.getSlice(), access.getIndices(),
        load, access->hasAttr(factory::attrFortranArrayOffsets()));
    // Find the merge_store paired with the load; the copy-out belongs just
    // before it.
    auto *storeOp = useMap.lookup(loadOp);
    auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
    rewriter.setInsertionPoint(storeOp);
    // Copy out.
    genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter, store.getMemref(),
                                   allocmem, shapeOp, store.getSlice(), load);
    genTempCleanUp(rewriter);
    return coor;
  }

  /// Copy the RHS element into the LHS and insert copy-in/copy-out between a
  /// temp and the LHS if the analysis found potential overlaps between the RHS
  /// and LHS arrays. The element copy generator must be provided in \p
  /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp.
  /// Returns the address of the LHS element inside the loop and the LHS
  /// ArrayLoad result.
  std::pair<mlir::Value, mlir::Value>
  materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter,
                        ArrayOp update,
                        const std::function<void(mlir::Value)> &assignElement,
                        mlir::Type lhsEltRefType) const {
    auto *op = update.getOperation();
    auto *loadOp = useMap.lookup(op);
    auto load = mlir::cast<ArrayLoadOp>(loadOp);
    LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n");
    if (analysis.hasPotentialConflict(loadOp)) {
      // If there is a conflict between the arrays, then we copy the lhs array
      // to a temporary, update the temporary, and copy the temporary back to
      // the lhs array. This yields Fortran's copy-in copy-out array semantics.
      LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n");
      rewriter.setInsertionPoint(loadOp);
      // Copy in.
      llvm::SmallVector<mlir::Value> extents;
      bool copyUsingSlice = false;
      auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
                                                copyUsingSlice);
      auto [allocmem, genTempCleanUp] =
          allocateArrayTemp(loc, rewriter, load, extents, shapeOp);

      genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
                                    load.getMemref(), shapeOp, load.getSlice(),
                                    load);
      // The element update itself is applied to the temporary.
      rewriter.setInsertionPoint(op);
      auto coor = genCoorOp(
          rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem,
          shapeOp, copyUsingSlice ? mlir::Value{} : load.getSlice(),
          update.getIndices(), load,
          update->hasAttr(factory::attrFortranArrayOffsets()));
      assignElement(coor);
      auto *storeOp = useMap.lookup(loadOp);
      auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
      rewriter.setInsertionPoint(storeOp);
      // Copy out.
      genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter,
                                     store.getMemref(), allocmem, shapeOp,
                                     store.getSlice(), load);
      genTempCleanUp(rewriter);
      return {coor, load.getResult()};
    }
    // Otherwise, when there is no conflict (a possible loop-carried
    // dependence), the lhs array can be updated in place.
    LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n");
    rewriter.setInsertionPoint(op);
    auto coorTy = getEleTy(load.getType());
    auto coor =
        genCoorOp(rewriter, loc, coorTy, lhsEltRefType, load.getMemref(),
                  load.getShape(), load.getSlice(), update.getIndices(), load,
                  update->hasAttr(factory::attrFortranArrayOffsets()));
    assignElement(coor);
    return {coor, load.getResult()};
  }

protected:
  const ArrayCopyAnalysis &analysis;
  const OperationUseMapT &useMap;
};
1194 
1195 class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> {
1196 public:
1197   explicit ArrayUpdateConversion(mlir::MLIRContext *ctx,
1198                                  const ArrayCopyAnalysis &a,
1199                                  const OperationUseMapT &m)
1200       : ArrayUpdateConversionBase{ctx, a, m} {}
1201 
1202   mlir::LogicalResult
1203   matchAndRewrite(ArrayUpdateOp update,
1204                   mlir::PatternRewriter &rewriter) const override {
1205     auto loc = update.getLoc();
1206     auto assignElement = [&](mlir::Value coor) {
1207       auto input = update.getMerge();
1208       if (auto inEleTy = dyn_cast_ptrEleTy(input.getType())) {
1209         emitFatalError(loc, "array_update on references not supported");
1210       } else {
1211         rewriter.create<fir::StoreOp>(loc, input, coor);
1212       }
1213     };
1214     auto lhsEltRefType = toRefType(update.getMerge().getType());
1215     auto [_, lhsLoadResult] = materializeAssignment(
1216         loc, rewriter, update, assignElement, lhsEltRefType);
1217     update.replaceAllUsesWith(lhsLoadResult);
1218     rewriter.replaceOp(update, lhsLoadResult);
1219     return mlir::success();
1220   }
1221 };
1222 
1223 class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> {
1224 public:
1225   explicit ArrayModifyConversion(mlir::MLIRContext *ctx,
1226                                  const ArrayCopyAnalysis &a,
1227                                  const OperationUseMapT &m)
1228       : ArrayUpdateConversionBase{ctx, a, m} {}
1229 
1230   mlir::LogicalResult
1231   matchAndRewrite(ArrayModifyOp modify,
1232                   mlir::PatternRewriter &rewriter) const override {
1233     auto loc = modify.getLoc();
1234     auto assignElement = [](mlir::Value) {
1235       // Assignment already materialized by lowering using lhs element address.
1236     };
1237     auto lhsEltRefType = modify.getResult(0).getType();
1238     auto [lhsEltCoor, lhsLoadResult] = materializeAssignment(
1239         loc, rewriter, modify, assignElement, lhsEltRefType);
1240     modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult});
1241     rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult});
1242     return mlir::success();
1243   }
1244 };
1245 
/// Convert a fir.array_fetch into an element address computation, followed by
/// a fir.load when the fetch yields a value (rather than a reference).
class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> {
public:
  explicit ArrayFetchConversion(mlir::MLIRContext *ctx,
                                const OperationUseMapT &m)
      : OpRewritePattern{ctx}, useMap{m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayFetchOp fetch,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = fetch.getOperation();
    rewriter.setInsertionPoint(op);
    // The fetch's associated array_load supplies memref/shape/slice.
    auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
    auto loc = fetch.getLoc();
    auto coor = genCoorOp(
        rewriter, loc, getEleTy(load.getType()), toRefType(fetch.getType()),
        load.getMemref(), load.getShape(), load.getSlice(), fetch.getIndices(),
        load, fetch->hasAttr(factory::attrFortranArrayOffsets()));
    if (isa_ref_type(fetch.getType()))
      // Reference result: the address itself is the replacement.
      rewriter.replaceOp(fetch, coor);
    else
      rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor);
    return mlir::success();
  }

private:
  const OperationUseMapT &useMap;
};
1273 
1274 /// As array_access op is like an array_fetch op, except that it does not imply
1275 /// a load op. (It operates in the reference domain.)
1276 class ArrayAccessConversion : public ArrayUpdateConversionBase<ArrayAccessOp> {
1277 public:
1278   explicit ArrayAccessConversion(mlir::MLIRContext *ctx,
1279                                  const ArrayCopyAnalysis &a,
1280                                  const OperationUseMapT &m)
1281       : ArrayUpdateConversionBase{ctx, a, m} {}
1282 
1283   mlir::LogicalResult
1284   matchAndRewrite(ArrayAccessOp access,
1285                   mlir::PatternRewriter &rewriter) const override {
1286     auto *op = access.getOperation();
1287     auto loc = access.getLoc();
1288     if (analysis.inAmendAccessSet(op)) {
1289       // This array_access is associated with an array_amend and there is a
1290       // conflict. Make a copy to store into.
1291       auto result = referenceToClone(loc, rewriter, access);
1292       access.replaceAllUsesWith(result);
1293       rewriter.replaceOp(access, result);
1294       return mlir::success();
1295     }
1296     rewriter.setInsertionPoint(op);
1297     auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
1298     auto coor = genCoorOp(
1299         rewriter, loc, getEleTy(load.getType()), toRefType(access.getType()),
1300         load.getMemref(), load.getShape(), load.getSlice(), access.getIndices(),
1301         load, access->hasAttr(factory::attrFortranArrayOffsets()));
1302     rewriter.replaceOp(access, coor);
1303     return mlir::success();
1304   }
1305 };
1306 
1307 /// An array_amend op is a marker to record which array access is being used to
1308 /// update an array value. After this pass runs, an array_amend has no
1309 /// semantics. We rewrite these to undefined values here to remove them while
1310 /// preserving SSA form.
1311 class ArrayAmendConversion : public mlir::OpRewritePattern<ArrayAmendOp> {
1312 public:
1313   explicit ArrayAmendConversion(mlir::MLIRContext *ctx)
1314       : OpRewritePattern{ctx} {}
1315 
1316   mlir::LogicalResult
1317   matchAndRewrite(ArrayAmendOp amend,
1318                   mlir::PatternRewriter &rewriter) const override {
1319     auto *op = amend.getOperation();
1320     rewriter.setInsertionPoint(op);
1321     auto loc = amend.getLoc();
1322     auto undef = rewriter.create<UndefOp>(loc, amend.getType());
1323     rewriter.replaceOp(amend, undef.getResult());
1324     return mlir::success();
1325   }
1326 };
1327 
/// Driver for the array-value-copy pass. Runs the conflict analysis, then
/// lowers array value operations in two phases: first the element accesses
/// (fetch/update/modify/access/amend), then the bracketing
/// array_load/array_merge_store pairs.
class ArrayValueCopyConverter
    : public ArrayValueCopyBase<ArrayValueCopyConverter> {
public:
  void runOnOperation() override {
    auto func = getOperation();
    LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '"
                            << func.getName() << "'\n");
    auto *context = &getContext();

    // Perform the conflict analysis.
    const auto &analysis = getAnalysis<ArrayCopyAnalysis>();
    const auto &useMap = analysis.getUseMap();

    // Phase 1: rewrite the element-level array ops. The loads/stores must
    // still be present, since the patterns use them to find the memref,
    // shape, slice, and copy-out position.
    mlir::RewritePatternSet patterns1(context);
    patterns1.insert<ArrayFetchConversion>(context, useMap);
    patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
    patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
    patterns1.insert<ArrayAccessConversion>(context, analysis, useMap);
    patterns1.insert<ArrayAmendConversion>(context);
    mlir::ConversionTarget target(*context);
    target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
                           mlir::arith::ArithmeticDialect,
                           mlir::func::FuncDialect>();
    target.addIllegalOp<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp,
                        ArrayUpdateOp, ArrayModifyOp>();
    // Rewrite the array fetch and array update ops.
    if (mlir::failed(
            mlir::applyPartialConversion(func, target, std::move(patterns1)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in array-value-copy pass, phase 1");
      signalPassFailure();
    }

    // Phase 2: the loads/stores are no longer needed; drop them.
    mlir::RewritePatternSet patterns2(context);
    patterns2.insert<ArrayLoadConversion>(context);
    patterns2.insert<ArrayMergeStoreConversion>(context);
    target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>();
    if (mlir::failed(
            mlir::applyPartialConversion(func, target, std::move(patterns2)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in array-value-copy pass, phase 2");
      signalPassFailure();
    }
  }
};
} // namespace
1374 
/// Create an instance of the array-value-copy pass, which rewrites FIR array
/// value operations into memory operations, inserting copy-in/copy-out
/// temporaries where the analysis finds potential conflicts.
std::unique_ptr<mlir::Pass> fir::createArrayValueCopyPass() {
  return std::make_unique<ArrayValueCopyConverter>();
}
1378