xref: /llvm-project/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp (revision 5e50dd048e3a20cde5da5d7a754dfee775ef35d6)
1 //===-- ArrayValueCopy.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "flang/Lower/Todo.h"
11 #include "flang/Optimizer/Builder/Array.h"
12 #include "flang/Optimizer/Builder/BoxValue.h"
13 #include "flang/Optimizer/Builder/FIRBuilder.h"
14 #include "flang/Optimizer/Builder/Factory.h"
15 #include "flang/Optimizer/Builder/Runtime/Derived.h"
16 #include "flang/Optimizer/Dialect/FIRDialect.h"
17 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
18 #include "flang/Optimizer/Support/FIRContext.h"
19 #include "flang/Optimizer/Transforms/Passes.h"
20 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
21 #include "mlir/Dialect/SCF/SCF.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/Support/Debug.h"
24 
25 #define DEBUG_TYPE "flang-array-value-copy"
26 
27 using namespace fir;
28 using namespace mlir;
29 
30 using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>;
31 
32 namespace {
33 
34 /// Array copy analysis.
35 /// Perform an interference analysis between array values.
36 ///
37 /// Lowering will generate a sequence of the following form.
38 /// ```mlir
39 ///   %a_1 = fir.array_load %array_1(%shape) : ...
40 ///   ...
41 ///   %a_j = fir.array_load %array_j(%shape) : ...
42 ///   ...
43 ///   %a_n = fir.array_load %array_n(%shape) : ...
44 ///     ...
45 ///     %v_i = fir.array_fetch %a_i, ...
46 ///     %a_j1 = fir.array_update %a_j, ...
47 ///     ...
48 ///   fir.array_merge_store %a_j, %a_jn to %array_j : ...
49 /// ```
50 ///
51 /// The analysis is to determine if there are any conflicts. A conflict is when
52 /// one the following cases occurs.
53 ///
54 /// 1. There is an `array_update` to an array value, a_j, such that a_j was
55 /// loaded from the same array memory reference (array_j) but with a different
56 /// shape as the other array values a_i, where i != j. [Possible overlapping
57 /// arrays.]
58 ///
59 /// 2. There is either an array_fetch or array_update of a_j with a different
60 /// set of index values. [Possible loop-carried dependence.]
61 ///
62 /// If none of the array values overlap in storage and the accesses are not
63 /// loop-carried, then the arrays are conflict-free and no copies are required.
64 class ArrayCopyAnalysis {
65 public:
66   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ArrayCopyAnalysis)
67 
68   using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>;
69   using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>;
70   using LoadMapSetsT = llvm::DenseMap<mlir::Operation *, UseSetT>;
71   using AmendAccessSetT = llvm::SmallPtrSet<mlir::Operation *, 4>;
72 
73   ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); }
74 
75   mlir::Operation *getOperation() const { return operation; }
76 
77   /// Return true iff the `array_merge_store` has potential conflicts.
78   bool hasPotentialConflict(mlir::Operation *op) const {
79     LLVM_DEBUG(llvm::dbgs()
80                << "looking for a conflict on " << *op
81                << " and the set has a total of " << conflicts.size() << '\n');
82     return conflicts.contains(op);
83   }
84 
85   /// Return the use map.
86   /// The use map maps array access, amend, fetch and update operations back to
87   /// the array load that is the original source of the array value.
88   /// It maps an array_load to an array_merge_store, if and only if the loaded
89   /// array value has pending modifications to be merged.
90   const OperationUseMapT &getUseMap() const { return useMap; }
91 
92   /// Return the set of array_access ops directly associated with array_amend
93   /// ops.
94   bool inAmendAccessSet(mlir::Operation *op) const {
95     return amendAccesses.count(op);
96   }
97 
98   /// For ArrayLoad `load`, return the transitive set of all OpOperands.
99   UseSetT getLoadUseSet(mlir::Operation *load) const {
100     assert(loadMapSets.count(load) && "analysis missed an array load?");
101     return loadMapSets.lookup(load);
102   }
103 
104   void arrayMentions(llvm::SmallVectorImpl<mlir::Operation *> &mentions,
105                      ArrayLoadOp load);
106 
107 private:
108   void construct(mlir::Operation *topLevelOp);
109 
110   mlir::Operation *operation; // operation that analysis ran upon
111   ConflictSetT conflicts;     // set of conflicts (loads and merge stores)
112   OperationUseMapT useMap;
113   LoadMapSetsT loadMapSets;
114   // Set of array_access ops associated with array_amend ops.
115   AmendAccessSetT amendAccesses;
116 };
117 } // namespace
118 
119 namespace {
120 /// Helper class to collect all array operations that produced an array value.
121 class ReachCollector {
122 public:
123   ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach,
124                  mlir::Region *loopRegion)
125       : reach{reach}, loopRegion{loopRegion} {}
126 
127   void collectArrayMentionFrom(mlir::Operation *op, mlir::ValueRange range) {
128     if (range.empty()) {
129       collectArrayMentionFrom(op, mlir::Value{});
130       return;
131     }
132     for (mlir::Value v : range)
133       collectArrayMentionFrom(v);
134   }
135 
136   // Collect all the array_access ops in `block`. This recursively looks into
137   // blocks in ops with regions.
138   // FIXME: This is temporarily relying on the array_amend appearing in a
139   // do_loop Region.  This phase ordering assumption can be eliminated by using
140   // dominance information to find the array_access ops or by scanning the
141   // transitive closure of the amending array_access's users and the defs that
142   // reach them.
143   void collectAccesses(llvm::SmallVector<ArrayAccessOp> &result,
144                        mlir::Block *block) {
145     for (auto &op : *block) {
146       if (auto access = mlir::dyn_cast<ArrayAccessOp>(op)) {
147         LLVM_DEBUG(llvm::dbgs() << "adding access: " << access << '\n');
148         result.push_back(access);
149         continue;
150       }
151       for (auto &region : op.getRegions())
152         for (auto &bb : region.getBlocks())
153           collectAccesses(result, &bb);
154     }
155   }
156 
157   void collectArrayMentionFrom(mlir::Operation *op, mlir::Value val) {
158     // `val` is defined by an Op, process the defining Op.
159     // If `val` is defined by a region containing Op, we want to drill down
160     // and through that Op's region(s).
161     LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n');
162     auto popFn = [&](auto rop) {
163       assert(val && "op must have a result value");
164       auto resNum = val.cast<mlir::OpResult>().getResultNumber();
165       llvm::SmallVector<mlir::Value> results;
166       rop.resultToSourceOps(results, resNum);
167       for (auto u : results)
168         collectArrayMentionFrom(u);
169     };
170     if (auto rop = mlir::dyn_cast<DoLoopOp>(op)) {
171       popFn(rop);
172       return;
173     }
174     if (auto rop = mlir::dyn_cast<IterWhileOp>(op)) {
175       popFn(rop);
176       return;
177     }
178     if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) {
179       popFn(rop);
180       return;
181     }
182     if (auto box = mlir::dyn_cast<EmboxOp>(op)) {
183       for (auto *user : box.getMemref().getUsers())
184         if (user != op)
185           collectArrayMentionFrom(user, user->getResults());
186       return;
187     }
188     if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
189       if (opIsInsideLoops(mergeStore))
190         collectArrayMentionFrom(mergeStore.getSequence());
191       return;
192     }
193 
194     if (mlir::isa<AllocaOp, AllocMemOp>(op)) {
195       // Look for any stores inside the loops, and collect an array operation
196       // that produced the value being stored to it.
197       for (auto *user : op->getUsers())
198         if (auto store = mlir::dyn_cast<fir::StoreOp>(user))
199           if (opIsInsideLoops(store))
200             collectArrayMentionFrom(store.getValue());
201       return;
202     }
203 
204     // Scan the uses of amend's memref
205     if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op)) {
206       reach.push_back(op);
207       llvm::SmallVector<ArrayAccessOp> accesses;
208       collectAccesses(accesses, op->getBlock());
209       for (auto access : accesses)
210         collectArrayMentionFrom(access.getResult());
211     }
212 
213     // Otherwise, Op does not contain a region so just chase its operands.
214     if (mlir::isa<ArrayAccessOp, ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp,
215                   ArrayFetchOp>(op)) {
216       LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
217       reach.push_back(op);
218     }
219 
220     // Include all array_access ops using an array_load.
221     if (auto arrLd = mlir::dyn_cast<ArrayLoadOp>(op))
222       for (auto *user : arrLd.getResult().getUsers())
223         if (mlir::isa<ArrayAccessOp>(user)) {
224           LLVM_DEBUG(llvm::dbgs() << "add " << *user << " to reachable set\n");
225           reach.push_back(user);
226         }
227 
228     // Array modify assignment is performed on the result. So the analysis must
229     // look at the what is done with the result.
230     if (mlir::isa<ArrayModifyOp>(op))
231       for (auto *user : op->getResult(0).getUsers())
232         followUsers(user);
233 
234     if (mlir::isa<fir::CallOp>(op)) {
235       LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
236       reach.push_back(op);
237     }
238 
239     for (auto u : op->getOperands())
240       collectArrayMentionFrom(u);
241   }
242 
243   void collectArrayMentionFrom(mlir::BlockArgument ba) {
244     auto *parent = ba.getOwner()->getParentOp();
245     // If inside an Op holding a region, the block argument corresponds to an
246     // argument passed to the containing Op.
247     auto popFn = [&](auto rop) {
248       collectArrayMentionFrom(rop.blockArgToSourceOp(ba.getArgNumber()));
249     };
250     if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) {
251       popFn(rop);
252       return;
253     }
254     if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) {
255       popFn(rop);
256       return;
257     }
258     // Otherwise, a block argument is provided via the pred blocks.
259     for (auto *pred : ba.getOwner()->getPredecessors()) {
260       auto u = pred->getTerminator()->getOperand(ba.getArgNumber());
261       collectArrayMentionFrom(u);
262     }
263   }
264 
265   // Recursively trace operands to find all array operations relating to the
266   // values merged.
267   void collectArrayMentionFrom(mlir::Value val) {
268     if (!val || visited.contains(val))
269       return;
270     visited.insert(val);
271 
272     // Process a block argument.
273     if (auto ba = val.dyn_cast<mlir::BlockArgument>()) {
274       collectArrayMentionFrom(ba);
275       return;
276     }
277 
278     // Process an Op.
279     if (auto *op = val.getDefiningOp()) {
280       collectArrayMentionFrom(op, val);
281       return;
282     }
283 
284     emitFatalError(val.getLoc(), "unhandled value");
285   }
286 
287   /// Return all ops that produce the array value that is stored into the
288   /// `array_merge_store`.
289   static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach,
290                              mlir::Value seq) {
291     reach.clear();
292     mlir::Region *loopRegion = nullptr;
293     if (auto doLoop = mlir::dyn_cast_or_null<DoLoopOp>(seq.getDefiningOp()))
294       loopRegion = &doLoop->getRegion(0);
295     ReachCollector collector(reach, loopRegion);
296     collector.collectArrayMentionFrom(seq);
297   }
298 
299 private:
300   /// Is \op inside the loop nest region ?
301   /// FIXME: replace this structural dependence with graph properties.
302   bool opIsInsideLoops(mlir::Operation *op) const {
303     auto *region = op->getParentRegion();
304     while (region) {
305       if (region == loopRegion)
306         return true;
307       region = region->getParentRegion();
308     }
309     return false;
310   }
311 
312   /// Recursively trace the use of an operation results, calling
313   /// collectArrayMentionFrom on the direct and indirect user operands.
314   void followUsers(mlir::Operation *op) {
315     for (auto userOperand : op->getOperands())
316       collectArrayMentionFrom(userOperand);
317     // Go through potential converts/coordinate_op.
318     for (auto indirectUser : op->getUsers())
319       followUsers(indirectUser);
320   }
321 
322   llvm::SmallVectorImpl<mlir::Operation *> &reach;
323   llvm::SmallPtrSet<mlir::Value, 16> visited;
324   /// Region of the loops nest that produced the array value.
325   mlir::Region *loopRegion;
326 };
327 } // namespace
328 
329 /// Find all the array operations that access the array value that is loaded by
330 /// the array load operation, `load`.
331 void ArrayCopyAnalysis::arrayMentions(
332     llvm::SmallVectorImpl<mlir::Operation *> &mentions, ArrayLoadOp load) {
333   mentions.clear();
334   auto lmIter = loadMapSets.find(load);
335   if (lmIter != loadMapSets.end()) {
336     for (auto *opnd : lmIter->second) {
337       auto *owner = opnd->getOwner();
338       if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
339                     ArrayModifyOp>(owner))
340         mentions.push_back(owner);
341     }
342     return;
343   }
344 
345   UseSetT visited;
346   llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig]
347 
348   auto appendToQueue = [&](mlir::Value val) {
349     for (auto &use : val.getUses())
350       if (!visited.count(&use)) {
351         visited.insert(&use);
352         queue.push_back(&use);
353       }
354   };
355 
356   // Build the set of uses of `original`.
357   // let USES = { uses of original fir.load }
358   appendToQueue(load);
359 
360   // Process the worklist until done.
361   while (!queue.empty()) {
362     mlir::OpOperand *operand = queue.pop_back_val();
363     mlir::Operation *owner = operand->getOwner();
364     if (!owner)
365       continue;
366     auto structuredLoop = [&](auto ro) {
367       if (auto blockArg = ro.iterArgToBlockArg(operand->get())) {
368         int64_t arg = blockArg.getArgNumber();
369         mlir::Value output = ro.getResult(ro.getFinalValue() ? arg : arg - 1);
370         appendToQueue(output);
371         appendToQueue(blockArg);
372       }
373     };
374     // TODO: this need to be updated to use the control-flow interface.
375     auto branchOp = [&](mlir::Block *dest, OperandRange operands) {
376       if (operands.empty())
377         return;
378 
379       // Check if this operand is within the range.
380       unsigned operandIndex = operand->getOperandNumber();
381       unsigned operandsStart = operands.getBeginOperandIndex();
382       if (operandIndex < operandsStart ||
383           operandIndex >= (operandsStart + operands.size()))
384         return;
385 
386       // Index the successor.
387       unsigned argIndex = operandIndex - operandsStart;
388       appendToQueue(dest->getArgument(argIndex));
389     };
390     // Thread uses into structured loop bodies and return value uses.
391     if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) {
392       structuredLoop(ro);
393     } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) {
394       structuredLoop(ro);
395     } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) {
396       // Thread any uses of fir.if that return the marked array value.
397       mlir::Operation *parent = rs->getParentRegion()->getParentOp();
398       if (auto ifOp = mlir::dyn_cast<fir::IfOp>(parent))
399         appendToQueue(ifOp.getResult(operand->getOperandNumber()));
400     } else if (mlir::isa<ArrayFetchOp>(owner)) {
401       // Keep track of array value fetches.
402       LLVM_DEBUG(llvm::dbgs()
403                  << "add fetch {" << *owner << "} to array value set\n");
404       mentions.push_back(owner);
405     } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) {
406       // Keep track of array value updates and thread the return value uses.
407       LLVM_DEBUG(llvm::dbgs()
408                  << "add update {" << *owner << "} to array value set\n");
409       mentions.push_back(owner);
410       appendToQueue(update.getResult());
411     } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) {
412       // Keep track of array value modification and thread the return value
413       // uses.
414       LLVM_DEBUG(llvm::dbgs()
415                  << "add modify {" << *owner << "} to array value set\n");
416       mentions.push_back(owner);
417       appendToQueue(update.getResult(1));
418     } else if (auto mention = mlir::dyn_cast<ArrayAccessOp>(owner)) {
419       mentions.push_back(owner);
420     } else if (auto amend = mlir::dyn_cast<ArrayAmendOp>(owner)) {
421       mentions.push_back(owner);
422       appendToQueue(amend.getResult());
423     } else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) {
424       branchOp(br.getDest(), br.getDestOperands());
425     } else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) {
426       branchOp(br.getTrueDest(), br.getTrueOperands());
427       branchOp(br.getFalseDest(), br.getFalseOperands());
428     } else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
429       // do nothing
430     } else {
431       llvm::report_fatal_error("array value reached unexpected op");
432     }
433   }
434   loadMapSets.insert({load, visited});
435 }
436 
437 static bool hasPointerType(mlir::Type type) {
438   if (auto boxTy = type.dyn_cast<BoxType>())
439     type = boxTy.getEleTy();
440   return type.isa<fir::PointerType>();
441 }
442 
443 // This is a NF performance hack. It makes a simple test that the slices of the
444 // load, \p ld, and the merge store, \p st, are trivially mutually exclusive.
445 static bool mutuallyExclusiveSliceRange(ArrayLoadOp ld, ArrayMergeStoreOp st) {
446   // If the same array_load, then no further testing is warranted.
447   if (ld.getResult() == st.getOriginal())
448     return false;
449 
450   auto getSliceOp = [](mlir::Value val) -> SliceOp {
451     if (!val)
452       return {};
453     auto sliceOp = mlir::dyn_cast_or_null<SliceOp>(val.getDefiningOp());
454     if (!sliceOp)
455       return {};
456     return sliceOp;
457   };
458 
459   auto ldSlice = getSliceOp(ld.getSlice());
460   auto stSlice = getSliceOp(st.getSlice());
461   if (!ldSlice || !stSlice)
462     return false;
463 
464   // Resign on subobject slices.
465   if (!ldSlice.getFields().empty() || !stSlice.getFields().empty() ||
466       !ldSlice.getSubstr().empty() || !stSlice.getSubstr().empty())
467     return false;
468 
469   // Crudely test that the two slices do not overlap by looking for the
470   // following general condition. If the slices look like (i:j) and (j+1:k) then
471   // these ranges do not overlap. The addend must be a constant.
472   auto ldTriples = ldSlice.getTriples();
473   auto stTriples = stSlice.getTriples();
474   const auto size = ldTriples.size();
475   if (size != stTriples.size())
476     return false;
477 
478   auto displacedByConstant = [](mlir::Value v1, mlir::Value v2) {
479     auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
480       auto *op = v.getDefiningOp();
481       while (auto conv = mlir::dyn_cast_or_null<ConvertOp>(op))
482         op = conv.getValue().getDefiningOp();
483       return op;
484     };
485 
486     auto isPositiveConstant = [](mlir::Value v) -> bool {
487       if (auto conOp =
488               mlir::dyn_cast<mlir::arith::ConstantOp>(v.getDefiningOp()))
489         if (auto iattr = conOp.getValue().dyn_cast<mlir::IntegerAttr>())
490           return iattr.getInt() > 0;
491       return false;
492     };
493 
494     auto *op1 = removeConvert(v1);
495     auto *op2 = removeConvert(v2);
496     if (!op1 || !op2)
497       return false;
498     if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
499       if ((addi.getLhs().getDefiningOp() == op1 &&
500            isPositiveConstant(addi.getRhs())) ||
501           (addi.getRhs().getDefiningOp() == op1 &&
502            isPositiveConstant(addi.getLhs())))
503         return true;
504     if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
505       if (subi.getLhs().getDefiningOp() == op2 &&
506           isPositiveConstant(subi.getRhs()))
507         return true;
508     return false;
509   };
510 
511   for (std::remove_const_t<decltype(size)> i = 0; i < size; i += 3) {
512     // If both are loop invariant, skip to the next triple.
513     if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i + 1].getDefiningOp()) &&
514         mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i + 1].getDefiningOp())) {
515       // Unless either is a vector index, then be conservative.
516       if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i].getDefiningOp()) ||
517           mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i].getDefiningOp()))
518         return false;
519       continue;
520     }
521     // If identical, skip to the next triple.
522     if (ldTriples[i] == stTriples[i] && ldTriples[i + 1] == stTriples[i + 1] &&
523         ldTriples[i + 2] == stTriples[i + 2])
524       continue;
525     // If ubound and lbound are the same with a constant offset, skip to the
526     // next triple.
527     if (displacedByConstant(ldTriples[i + 1], stTriples[i]) ||
528         displacedByConstant(stTriples[i + 1], ldTriples[i]))
529       continue;
530     return false;
531   }
532   LLVM_DEBUG(llvm::dbgs() << "detected non-overlapping slice ranges on " << ld
533                           << " and " << st << ", which is not a conflict\n");
534   return true;
535 }
536 
537 /// Is there a conflict between the array value that was updated and to be
538 /// stored to `st` and the set of arrays loaded (`reach`) and used to compute
539 /// the updated value?
540 static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach,
541                            ArrayMergeStoreOp st) {
542   mlir::Value load;
543   mlir::Value addr = st.getMemref();
544   const bool storeHasPointerType = hasPointerType(addr.getType());
545   for (auto *op : reach)
546     if (auto ld = mlir::dyn_cast<ArrayLoadOp>(op)) {
547       mlir::Type ldTy = ld.getMemref().getType();
548       if (ld.getMemref() == addr) {
549         if (mutuallyExclusiveSliceRange(ld, st))
550           continue;
551         if (ld.getResult() != st.getOriginal())
552           return true;
553         if (load) {
554           // TODO: extend this to allow checking if the first `load` and this
555           // `ld` are mutually exclusive accesses but not identical.
556           return true;
557         }
558         load = ld;
559       } else if ((hasPointerType(ldTy) || storeHasPointerType)) {
560         // TODO: Use target attribute to restrict this case further.
561         // TODO: Check if types can also allow ruling out some cases. For now,
562         // the fact that equivalences is using pointer attribute to enforce
563         // aliasing is preventing any attempt to do so, and in general, it may
564         // be wrong to use this if any of the types is a complex or a derived
565         // for which it is possible to create a pointer to a part with a
566         // different type than the whole, although this deserve some more
567         // investigation because existing compiler behavior seem to diverge
568         // here.
569         return true;
570       }
571     }
572   return false;
573 }
574 
575 /// Is there an access vector conflict on the array being merged into? If the
576 /// access vectors diverge, then assume that there are potentially overlapping
577 /// loop-carried references.
578 static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> mentions) {
579   if (mentions.size() < 2)
580     return false;
581   llvm::SmallVector<mlir::Value> indices;
582   LLVM_DEBUG(llvm::dbgs() << "check merge conflict on with " << mentions.size()
583                           << " mentions on the list\n");
584   bool valSeen = false;
585   bool refSeen = false;
586   for (auto *op : mentions) {
587     llvm::SmallVector<mlir::Value> compareVector;
588     if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) {
589       valSeen = true;
590       if (indices.empty()) {
591         indices = u.getIndices();
592         continue;
593       }
594       compareVector = u.getIndices();
595     } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) {
596       valSeen = true;
597       if (indices.empty()) {
598         indices = f.getIndices();
599         continue;
600       }
601       compareVector = f.getIndices();
602     } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) {
603       valSeen = true;
604       if (indices.empty()) {
605         indices = f.getIndices();
606         continue;
607       }
608       compareVector = f.getIndices();
609     } else if (auto f = mlir::dyn_cast<ArrayAccessOp>(op)) {
610       refSeen = true;
611       if (indices.empty()) {
612         indices = f.getIndices();
613         continue;
614       }
615       compareVector = f.getIndices();
616     } else if (mlir::isa<ArrayAmendOp>(op)) {
617       refSeen = true;
618       continue;
619     } else {
620       mlir::emitError(op->getLoc(), "unexpected operation in analysis");
621     }
622     if (compareVector.size() != indices.size() ||
623         llvm::any_of(llvm::zip(compareVector, indices), [&](auto pair) {
624           return std::get<0>(pair) != std::get<1>(pair);
625         }))
626       return true;
627     LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n");
628   }
629   return valSeen && refSeen;
630 }
631 
632 /// With element-by-reference semantics, an amended array with more than once
633 /// access to the same loaded array are conservatively considered a conflict.
634 /// Note: the array copy can still be eliminated in subsequent optimizations.
635 static bool conflictOnReference(llvm::ArrayRef<mlir::Operation *> mentions) {
636   LLVM_DEBUG(llvm::dbgs() << "checking reference semantics " << mentions.size()
637                           << '\n');
638   if (mentions.size() < 3)
639     return false;
640   unsigned amendCount = 0;
641   unsigned accessCount = 0;
642   for (auto *op : mentions) {
643     if (mlir::isa<ArrayAmendOp>(op) && ++amendCount > 1) {
644       LLVM_DEBUG(llvm::dbgs() << "conflict: multiple amends of array value\n");
645       return true;
646     }
647     if (mlir::isa<ArrayAccessOp>(op) && ++accessCount > 1) {
648       LLVM_DEBUG(llvm::dbgs()
649                  << "conflict: multiple accesses of array value\n");
650       return true;
651     }
652     if (mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) {
653       LLVM_DEBUG(llvm::dbgs()
654                  << "conflict: array value has both uses by-value and uses "
655                     "by-reference. conservative assumption.\n");
656       return true;
657     }
658   }
659   return false;
660 }
661 
662 static mlir::Operation *
663 amendingAccess(llvm::ArrayRef<mlir::Operation *> mentions) {
664   for (auto *op : mentions)
665     if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op))
666       return amend.getMemref().getDefiningOp();
667   return {};
668 }
669 
670 // Are either of types of conflicts present?
671 inline bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
672                              llvm::ArrayRef<mlir::Operation *> accesses,
673                              ArrayMergeStoreOp st) {
674   return conflictOnLoad(reach, st) || conflictOnMerge(accesses);
675 }
676 
677 // Assume that any call to a function that uses host-associations will be
678 // modifying the output array.
679 static bool
680 conservativeCallConflict(llvm::ArrayRef<mlir::Operation *> reaches) {
681   return llvm::any_of(reaches, [](mlir::Operation *op) {
682     if (auto call = mlir::dyn_cast<fir::CallOp>(op))
683       if (auto callee =
684               call.getCallableForCallee().dyn_cast<mlir::SymbolRefAttr>()) {
685         auto module = op->getParentOfType<mlir::ModuleOp>();
686         return hasHostAssociationArgument(
687             module.lookupSymbol<mlir::FuncOp>(callee));
688       }
689     return false;
690   });
691 }
692 
693 /// Constructor of the array copy analysis.
694 /// This performs the analysis and saves the intermediate results.
695 void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) {
696   topLevelOp->walk([&](Operation *op) {
697     if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) {
698       llvm::SmallVector<mlir::Operation *> values;
699       ReachCollector::reachingValues(values, st.getSequence());
700       bool callConflict = conservativeCallConflict(values);
701       llvm::SmallVector<mlir::Operation *> mentions;
702       arrayMentions(mentions,
703                     mlir::cast<ArrayLoadOp>(st.getOriginal().getDefiningOp()));
704       bool conflict = conflictDetected(values, mentions, st);
705       bool refConflict = conflictOnReference(mentions);
706       if (callConflict || conflict || refConflict) {
707         LLVM_DEBUG(llvm::dbgs()
708                    << "CONFLICT: copies required for " << st << '\n'
709                    << "   adding conflicts on: " << op << " and "
710                    << st.getOriginal() << '\n');
711         conflicts.insert(op);
712         conflicts.insert(st.getOriginal().getDefiningOp());
713         if (auto *access = amendingAccess(mentions))
714           amendAccesses.insert(access);
715       }
716       auto *ld = st.getOriginal().getDefiningOp();
717       LLVM_DEBUG(llvm::dbgs()
718                  << "map: adding {" << *ld << " -> " << st << "}\n");
719       useMap.insert({ld, op});
720     } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) {
721       llvm::SmallVector<mlir::Operation *> mentions;
722       arrayMentions(mentions, load);
723       LLVM_DEBUG(llvm::dbgs() << "process load: " << load
724                               << ", mentions: " << mentions.size() << '\n');
725       for (auto *acc : mentions) {
726         LLVM_DEBUG(llvm::dbgs() << " mention: " << *acc << '\n');
727         if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
728                       ArrayModifyOp>(acc)) {
729           if (useMap.count(acc)) {
730             mlir::emitError(
731                 load.getLoc(),
732                 "The parallel semantics of multiple array_merge_stores per "
733                 "array_load are not supported.");
734             continue;
735           }
736           LLVM_DEBUG(llvm::dbgs()
737                      << "map: adding {" << *acc << "} -> {" << load << "}\n");
738           useMap.insert({acc, op});
739         }
740       }
741     }
742   });
743 }
744 
745 //===----------------------------------------------------------------------===//
746 // Conversions for converting out of array value form.
747 //===----------------------------------------------------------------------===//
748 
749 namespace {
750 class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> {
751 public:
752   using OpRewritePattern::OpRewritePattern;
753 
754   mlir::LogicalResult
755   matchAndRewrite(ArrayLoadOp load,
756                   mlir::PatternRewriter &rewriter) const override {
757     LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n");
758     rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType());
759     return mlir::success();
760   }
761 };
762 
763 class ArrayMergeStoreConversion
764     : public mlir::OpRewritePattern<ArrayMergeStoreOp> {
765 public:
766   using OpRewritePattern::OpRewritePattern;
767 
768   mlir::LogicalResult
769   matchAndRewrite(ArrayMergeStoreOp store,
770                   mlir::PatternRewriter &rewriter) const override {
771     LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n");
772     rewriter.eraseOp(store);
773     return mlir::success();
774   }
775 };
776 } // namespace
777 
778 static mlir::Type getEleTy(mlir::Type ty) {
779   auto eleTy = unwrapSequenceType(unwrapPassByRefType(ty));
780   // FIXME: keep ptr/heap/ref information.
781   return ReferenceType::get(eleTy);
782 }
783 
784 // Extract extents from the ShapeOp/ShapeShiftOp into the result vector.
785 static bool getAdjustedExtents(mlir::Location loc,
786                                mlir::PatternRewriter &rewriter,
787                                ArrayLoadOp arrLoad,
788                                llvm::SmallVectorImpl<mlir::Value> &result,
789                                mlir::Value shape) {
790   bool copyUsingSlice = false;
791   auto *shapeOp = shape.getDefiningOp();
792   if (auto s = mlir::dyn_cast_or_null<ShapeOp>(shapeOp)) {
793     auto e = s.getExtents();
794     result.insert(result.end(), e.begin(), e.end());
795   } else if (auto s = mlir::dyn_cast_or_null<ShapeShiftOp>(shapeOp)) {
796     auto e = s.getExtents();
797     result.insert(result.end(), e.begin(), e.end());
798   } else {
799     emitFatalError(loc, "not a fir.shape/fir.shape_shift op");
800   }
801   auto idxTy = rewriter.getIndexType();
802   if (factory::isAssumedSize(result)) {
803     // Use slice information to compute the extent of the column.
804     auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
805     mlir::Value size = one;
806     if (mlir::Value sliceArg = arrLoad.getSlice()) {
807       if (auto sliceOp =
808               mlir::dyn_cast_or_null<SliceOp>(sliceArg.getDefiningOp())) {
809         auto triples = sliceOp.getTriples();
810         const std::size_t tripleSize = triples.size();
811         auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
812         FirOpBuilder builder(rewriter, getKindMapping(module));
813         size = builder.genExtentFromTriplet(loc, triples[tripleSize - 3],
814                                             triples[tripleSize - 2],
815                                             triples[tripleSize - 1], idxTy);
816         copyUsingSlice = true;
817       }
818     }
819     result[result.size() - 1] = size;
820   }
821   return copyUsingSlice;
822 }
823 
824 /// Place the extents of the array load, \p arrLoad, into \p result and
825 /// return a ShapeOp or ShapeShiftOp with the same extents. If \p arrLoad is
826 /// loading a `!fir.box`, code will be generated to read the extents from the
827 /// boxed value, and the retunred shape Op will be built with the extents read
828 /// from the box. Otherwise, the extents will be extracted from the ShapeOp (or
829 /// ShapeShiftOp) argument of \p arrLoad. \p copyUsingSlice will be set to true
830 /// if slicing of the output array is to be done in the copy-in/copy-out rather
831 /// than in the elemental computation step.
832 static mlir::Value getOrReadExtentsAndShapeOp(
833     mlir::Location loc, mlir::PatternRewriter &rewriter, ArrayLoadOp arrLoad,
834     llvm::SmallVectorImpl<mlir::Value> &result, bool &copyUsingSlice) {
835   assert(result.empty());
836   if (arrLoad->hasAttr(fir::getOptionalAttrName()))
837     fir::emitFatalError(
838         loc, "shapes from array load of OPTIONAL arrays must not be used");
839   if (auto boxTy = arrLoad.getMemref().getType().dyn_cast<BoxType>()) {
840     auto rank =
841         dyn_cast_ptrOrBoxEleTy(boxTy).cast<SequenceType>().getDimension();
842     auto idxTy = rewriter.getIndexType();
843     for (decltype(rank) dim = 0; dim < rank; ++dim) {
844       auto dimVal = rewriter.create<mlir::arith::ConstantIndexOp>(loc, dim);
845       auto dimInfo = rewriter.create<BoxDimsOp>(loc, idxTy, idxTy, idxTy,
846                                                 arrLoad.getMemref(), dimVal);
847       result.emplace_back(dimInfo.getResult(1));
848     }
849     if (!arrLoad.getShape()) {
850       auto shapeType = ShapeType::get(rewriter.getContext(), rank);
851       return rewriter.create<ShapeOp>(loc, shapeType, result);
852     }
853     auto shiftOp = arrLoad.getShape().getDefiningOp<ShiftOp>();
854     auto shapeShiftType = ShapeShiftType::get(rewriter.getContext(), rank);
855     llvm::SmallVector<mlir::Value> shapeShiftOperands;
856     for (auto [lb, extent] : llvm::zip(shiftOp.getOrigins(), result)) {
857       shapeShiftOperands.push_back(lb);
858       shapeShiftOperands.push_back(extent);
859     }
860     return rewriter.create<ShapeShiftOp>(loc, shapeShiftType,
861                                          shapeShiftOperands);
862   }
863   copyUsingSlice =
864       getAdjustedExtents(loc, rewriter, arrLoad, result, arrLoad.getShape());
865   return arrLoad.getShape();
866 }
867 
868 static mlir::Type toRefType(mlir::Type ty) {
869   if (fir::isa_ref_type(ty))
870     return ty;
871   return fir::ReferenceType::get(ty);
872 }
873 
874 static mlir::Value
875 genCoorOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type eleTy,
876           mlir::Type resTy, mlir::Value alloc, mlir::Value shape,
877           mlir::Value slice, mlir::ValueRange indices,
878           mlir::ValueRange typeparams, bool skipOrig = false) {
879   llvm::SmallVector<mlir::Value> originated;
880   if (skipOrig)
881     originated.assign(indices.begin(), indices.end());
882   else
883     originated = fir::factory::originateIndices(loc, rewriter, alloc.getType(),
884                                                 shape, indices);
885   auto seqTy = fir::dyn_cast_ptrOrBoxEleTy(alloc.getType());
886   assert(seqTy && seqTy.isa<fir::SequenceType>());
887   const auto dimension = seqTy.cast<fir::SequenceType>().getDimension();
888   mlir::Value result = rewriter.create<fir::ArrayCoorOp>(
889       loc, eleTy, alloc, shape, slice,
890       llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension),
891       typeparams);
892   if (dimension < originated.size())
893     result = rewriter.create<fir::CoordinateOp>(
894         loc, resTy, result,
895         llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension));
896   return result;
897 }
898 
899 static mlir::Value getCharacterLen(mlir::Location loc, FirOpBuilder &builder,
900                                    ArrayLoadOp load, CharacterType charTy) {
901   auto charLenTy = builder.getCharacterLengthType();
902   if (charTy.hasDynamicLen()) {
903     if (load.getMemref().getType().isa<BoxType>()) {
904       // The loaded array is an emboxed value. Get the CHARACTER length from
905       // the box value.
906       auto eleSzInBytes =
907           builder.create<BoxEleSizeOp>(loc, charLenTy, load.getMemref());
908       auto kindSize =
909           builder.getKindMap().getCharacterBitsize(charTy.getFKind());
910       auto kindByteSize =
911           builder.createIntegerConstant(loc, charLenTy, kindSize / 8);
912       return builder.create<mlir::arith::DivSIOp>(loc, eleSzInBytes,
913                                                   kindByteSize);
914     }
915     // The loaded array is a (set of) unboxed values. If the CHARACTER's
916     // length is not a constant, it must be provided as a type parameter to
917     // the array_load.
918     auto typeparams = load.getTypeparams();
919     assert(typeparams.size() > 0 && "expected type parameters on array_load");
920     return typeparams.back();
921   }
922   // The typical case: the length of the CHARACTER is a compile-time
923   // constant that is encoded in the type information.
924   return builder.createIntegerConstant(loc, charLenTy, charTy.getLen());
925 }
926 /// Generate a shallow array copy. This is used for both copy-in and copy-out.
927 template <bool CopyIn>
928 void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter,
929                   mlir::Value dst, mlir::Value src, mlir::Value shapeOp,
930                   mlir::Value sliceOp, ArrayLoadOp arrLoad) {
931   auto insPt = rewriter.saveInsertionPoint();
932   llvm::SmallVector<mlir::Value> indices;
933   llvm::SmallVector<mlir::Value> extents;
934   bool copyUsingSlice =
935       getAdjustedExtents(loc, rewriter, arrLoad, extents, shapeOp);
936   auto idxTy = rewriter.getIndexType();
937   // Build loop nest from column to row.
938   for (auto sh : llvm::reverse(extents)) {
939     auto ubi = rewriter.create<ConvertOp>(loc, idxTy, sh);
940     auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
941     auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
942     auto ub = rewriter.create<mlir::arith::SubIOp>(loc, idxTy, ubi, one);
943     auto loop = rewriter.create<DoLoopOp>(loc, zero, ub, one);
944     rewriter.setInsertionPointToStart(loop.getBody());
945     indices.push_back(loop.getInductionVar());
946   }
947   // Reverse the indices so they are in column-major order.
948   std::reverse(indices.begin(), indices.end());
949   auto typeparams = arrLoad.getTypeparams();
950   auto fromAddr = rewriter.create<ArrayCoorOp>(
951       loc, getEleTy(src.getType()), src, shapeOp,
952       CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
953       factory::originateIndices(loc, rewriter, src.getType(), shapeOp, indices),
954       typeparams);
955   auto toAddr = rewriter.create<ArrayCoorOp>(
956       loc, getEleTy(dst.getType()), dst, shapeOp,
957       !CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
958       factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, indices),
959       typeparams);
960   auto eleTy = unwrapSequenceType(unwrapPassByRefType(dst.getType()));
961   auto module = toAddr->getParentOfType<mlir::ModuleOp>();
962   FirOpBuilder builder(rewriter, getKindMapping(module));
963   // Copy from (to) object to (from) temp copy of same object.
964   if (auto charTy = eleTy.dyn_cast<CharacterType>()) {
965     auto len = getCharacterLen(loc, builder, arrLoad, charTy);
966     CharBoxValue toChar(toAddr, len);
967     CharBoxValue fromChar(fromAddr, len);
968     factory::genScalarAssignment(builder, loc, toChar, fromChar);
969   } else {
970     if (hasDynamicSize(eleTy))
971       TODO(loc, "copy element of dynamic size");
972     factory::genScalarAssignment(builder, loc, toAddr, fromAddr);
973   }
974   rewriter.restoreInsertionPoint(insPt);
975 }
976 
977 /// The array load may be either a boxed or unboxed value. If the value is
978 /// boxed, we read the type parameters from the boxed value.
979 static llvm::SmallVector<mlir::Value>
980 genArrayLoadTypeParameters(mlir::Location loc, mlir::PatternRewriter &rewriter,
981                            ArrayLoadOp load) {
982   if (load.getTypeparams().empty()) {
983     auto eleTy =
984         unwrapSequenceType(unwrapPassByRefType(load.getMemref().getType()));
985     if (hasDynamicSize(eleTy)) {
986       if (auto charTy = eleTy.dyn_cast<CharacterType>()) {
987         assert(load.getMemref().getType().isa<BoxType>());
988         auto module = load->getParentOfType<mlir::ModuleOp>();
989         FirOpBuilder builder(rewriter, getKindMapping(module));
990         return {getCharacterLen(loc, builder, load, charTy)};
991       }
992       TODO(loc, "unhandled dynamic type parameters");
993     }
994     return {};
995   }
996   return load.getTypeparams();
997 }
998 
999 static llvm::SmallVector<mlir::Value>
1000 findNonconstantExtents(mlir::Type memrefTy,
1001                        llvm::ArrayRef<mlir::Value> extents) {
1002   llvm::SmallVector<mlir::Value> nce;
1003   auto arrTy = unwrapPassByRefType(memrefTy);
1004   auto seqTy = arrTy.cast<SequenceType>();
1005   for (auto [s, x] : llvm::zip(seqTy.getShape(), extents))
1006     if (s == SequenceType::getUnknownExtent())
1007       nce.emplace_back(x);
1008   if (extents.size() > seqTy.getShape().size())
1009     for (auto x : extents.drop_front(seqTy.getShape().size()))
1010       nce.emplace_back(x);
1011   return nce;
1012 }
1013 
1014 /// Allocate temporary storage for an ArrayLoadOp \load and initialize any
1015 /// allocatable direct components of the array elements with an unallocated
1016 /// status. Returns the temporary address as well as a callback to generate the
1017 /// temporary clean-up once it has been used. The clean-up will take care of
1018 /// deallocating all the element allocatable components that may have been
1019 /// allocated while using the temporary.
1020 static std::pair<mlir::Value,
1021                  std::function<void(mlir::PatternRewriter &rewriter)>>
1022 allocateArrayTemp(mlir::Location loc, mlir::PatternRewriter &rewriter,
1023                   ArrayLoadOp load, llvm::ArrayRef<mlir::Value> extents,
1024                   mlir::Value shape) {
1025   mlir::Type baseType = load.getMemref().getType();
1026   llvm::SmallVector<mlir::Value> nonconstantExtents =
1027       findNonconstantExtents(baseType, extents);
1028   llvm::SmallVector<mlir::Value> typeParams =
1029       genArrayLoadTypeParameters(loc, rewriter, load);
1030   mlir::Value allocmem = rewriter.create<AllocMemOp>(
1031       loc, dyn_cast_ptrOrBoxEleTy(baseType), typeParams, nonconstantExtents);
1032   mlir::Type eleType =
1033       fir::unwrapSequenceType(fir::unwrapPassByRefType(baseType));
1034   if (fir::isRecordWithAllocatableMember(eleType)) {
1035     // The allocatable component descriptors need to be set to a clean
1036     // deallocated status before anything is done with them.
1037     mlir::Value box = rewriter.create<fir::EmboxOp>(
1038         loc, fir::BoxType::get(baseType), allocmem, shape,
1039         /*slice=*/mlir::Value{}, typeParams);
1040     auto module = load->getParentOfType<mlir::ModuleOp>();
1041     FirOpBuilder builder(rewriter, getKindMapping(module));
1042     runtime::genDerivedTypeInitialize(builder, loc, box);
1043     // Any allocatable component that may have been allocated must be
1044     // deallocated during the clean-up.
1045     auto cleanup = [=](mlir::PatternRewriter &r) {
1046       FirOpBuilder builder(r, getKindMapping(module));
1047       runtime::genDerivedTypeDestroy(builder, loc, box);
1048       r.create<FreeMemOp>(loc, allocmem);
1049     };
1050     return {allocmem, cleanup};
1051   }
1052   auto cleanup = [=](mlir::PatternRewriter &r) {
1053     r.create<FreeMemOp>(loc, allocmem);
1054   };
1055   return {allocmem, cleanup};
1056 }
1057 
1058 namespace {
1059 /// Conversion of fir.array_update and fir.array_modify Ops.
1060 /// If there is a conflict for the update, then we need to perform a
1061 /// copy-in/copy-out to preserve the original values of the array. If there is
1062 /// no conflict, then it is save to eschew making any copies.
1063 template <typename ArrayOp>
1064 class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> {
1065 public:
1066   // TODO: Implement copy/swap semantics?
1067   explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx,
1068                                      const ArrayCopyAnalysis &a,
1069                                      const OperationUseMapT &m)
1070       : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {}
1071 
1072   /// The array_access, \p access, is to be to a cloned copy due to a potential
1073   /// conflict. Uses copy-in/copy-out semantics and not copy/swap.
1074   mlir::Value referenceToClone(mlir::Location loc,
1075                                mlir::PatternRewriter &rewriter,
1076                                ArrayOp access) const {
1077     LLVM_DEBUG(llvm::dbgs()
1078                << "generating copy-in/copy-out loops for " << access << '\n');
1079     auto *op = access.getOperation();
1080     auto *loadOp = useMap.lookup(op);
1081     auto load = mlir::cast<ArrayLoadOp>(loadOp);
1082     auto eleTy = access.getType();
1083     rewriter.setInsertionPoint(loadOp);
1084     // Copy in.
1085     llvm::SmallVector<mlir::Value> extents;
1086     bool copyUsingSlice = false;
1087     auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
1088                                               copyUsingSlice);
1089     auto [allocmem, genTempCleanUp] =
1090         allocateArrayTemp(loc, rewriter, load, extents, shapeOp);
1091     genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
1092                                   load.getMemref(), shapeOp, load.getSlice(),
1093                                   load);
1094     // Generate the reference for the access.
1095     rewriter.setInsertionPoint(op);
1096     auto coor =
1097         genCoorOp(rewriter, loc, getEleTy(load.getType()), eleTy, allocmem,
1098                   shapeOp, copyUsingSlice ? mlir::Value{} : load.getSlice(),
1099                   access.getIndices(), load.getTypeparams(),
1100                   access->hasAttr(factory::attrFortranArrayOffsets()));
1101     // Copy out.
1102     auto *storeOp = useMap.lookup(loadOp);
1103     auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
1104     rewriter.setInsertionPoint(storeOp);
1105     // Copy out.
1106     genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter, store.getMemref(),
1107                                    allocmem, shapeOp, store.getSlice(), load);
1108     genTempCleanUp(rewriter);
1109     return coor;
1110   }
1111 
1112   /// Copy the RHS element into the LHS and insert copy-in/copy-out between a
1113   /// temp and the LHS if the analysis found potential overlaps between the RHS
1114   /// and LHS arrays. The element copy generator must be provided in \p
1115   /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp.
1116   /// Returns the address of the LHS element inside the loop and the LHS
1117   /// ArrayLoad result.
1118   std::pair<mlir::Value, mlir::Value>
1119   materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter,
1120                         ArrayOp update,
1121                         const std::function<void(mlir::Value)> &assignElement,
1122                         mlir::Type lhsEltRefType) const {
1123     auto *op = update.getOperation();
1124     auto *loadOp = useMap.lookup(op);
1125     auto load = mlir::cast<ArrayLoadOp>(loadOp);
1126     LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n");
1127     if (analysis.hasPotentialConflict(loadOp)) {
1128       // If there is a conflict between the arrays, then we copy the lhs array
1129       // to a temporary, update the temporary, and copy the temporary back to
1130       // the lhs array. This yields Fortran's copy-in copy-out array semantics.
1131       LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n");
1132       rewriter.setInsertionPoint(loadOp);
1133       // Copy in.
1134       llvm::SmallVector<mlir::Value> extents;
1135       bool copyUsingSlice = false;
1136       auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
1137                                                 copyUsingSlice);
1138       auto [allocmem, genTempCleanUp] =
1139           allocateArrayTemp(loc, rewriter, load, extents, shapeOp);
1140 
1141       genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
1142                                     load.getMemref(), shapeOp, load.getSlice(),
1143                                     load);
1144       rewriter.setInsertionPoint(op);
1145       auto coor = genCoorOp(
1146           rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem,
1147           shapeOp, copyUsingSlice ? mlir::Value{} : load.getSlice(),
1148           update.getIndices(), load.getTypeparams(),
1149           update->hasAttr(factory::attrFortranArrayOffsets()));
1150       assignElement(coor);
1151       auto *storeOp = useMap.lookup(loadOp);
1152       auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
1153       rewriter.setInsertionPoint(storeOp);
1154       // Copy out.
1155       genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter,
1156                                      store.getMemref(), allocmem, shapeOp,
1157                                      store.getSlice(), load);
1158       genTempCleanUp(rewriter);
1159       return {coor, load.getResult()};
1160     }
1161     // Otherwise, when there is no conflict (a possible loop-carried
1162     // dependence), the lhs array can be updated in place.
1163     LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n");
1164     rewriter.setInsertionPoint(op);
1165     auto coorTy = getEleTy(load.getType());
1166     auto coor = genCoorOp(rewriter, loc, coorTy, lhsEltRefType,
1167                           load.getMemref(), load.getShape(), load.getSlice(),
1168                           update.getIndices(), load.getTypeparams(),
1169                           update->hasAttr(factory::attrFortranArrayOffsets()));
1170     assignElement(coor);
1171     return {coor, load.getResult()};
1172   }
1173 
1174 protected:
1175   const ArrayCopyAnalysis &analysis;
1176   const OperationUseMapT &useMap;
1177 };
1178 
1179 class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> {
1180 public:
1181   explicit ArrayUpdateConversion(mlir::MLIRContext *ctx,
1182                                  const ArrayCopyAnalysis &a,
1183                                  const OperationUseMapT &m)
1184       : ArrayUpdateConversionBase{ctx, a, m} {}
1185 
1186   mlir::LogicalResult
1187   matchAndRewrite(ArrayUpdateOp update,
1188                   mlir::PatternRewriter &rewriter) const override {
1189     auto loc = update.getLoc();
1190     auto assignElement = [&](mlir::Value coor) {
1191       auto input = update.getMerge();
1192       if (auto inEleTy = dyn_cast_ptrEleTy(input.getType())) {
1193         emitFatalError(loc, "array_update on references not supported");
1194       } else {
1195         rewriter.create<fir::StoreOp>(loc, input, coor);
1196       }
1197     };
1198     auto lhsEltRefType = toRefType(update.getMerge().getType());
1199     auto [_, lhsLoadResult] = materializeAssignment(
1200         loc, rewriter, update, assignElement, lhsEltRefType);
1201     update.replaceAllUsesWith(lhsLoadResult);
1202     rewriter.replaceOp(update, lhsLoadResult);
1203     return mlir::success();
1204   }
1205 };
1206 
1207 class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> {
1208 public:
1209   explicit ArrayModifyConversion(mlir::MLIRContext *ctx,
1210                                  const ArrayCopyAnalysis &a,
1211                                  const OperationUseMapT &m)
1212       : ArrayUpdateConversionBase{ctx, a, m} {}
1213 
1214   mlir::LogicalResult
1215   matchAndRewrite(ArrayModifyOp modify,
1216                   mlir::PatternRewriter &rewriter) const override {
1217     auto loc = modify.getLoc();
1218     auto assignElement = [](mlir::Value) {
1219       // Assignment already materialized by lowering using lhs element address.
1220     };
1221     auto lhsEltRefType = modify.getResult(0).getType();
1222     auto [lhsEltCoor, lhsLoadResult] = materializeAssignment(
1223         loc, rewriter, modify, assignElement, lhsEltRefType);
1224     modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult});
1225     rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult});
1226     return mlir::success();
1227   }
1228 };
1229 
1230 class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> {
1231 public:
1232   explicit ArrayFetchConversion(mlir::MLIRContext *ctx,
1233                                 const OperationUseMapT &m)
1234       : OpRewritePattern{ctx}, useMap{m} {}
1235 
1236   mlir::LogicalResult
1237   matchAndRewrite(ArrayFetchOp fetch,
1238                   mlir::PatternRewriter &rewriter) const override {
1239     auto *op = fetch.getOperation();
1240     rewriter.setInsertionPoint(op);
1241     auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
1242     auto loc = fetch.getLoc();
1243     auto coor =
1244         genCoorOp(rewriter, loc, getEleTy(load.getType()),
1245                   toRefType(fetch.getType()), load.getMemref(), load.getShape(),
1246                   load.getSlice(), fetch.getIndices(), load.getTypeparams(),
1247                   fetch->hasAttr(factory::attrFortranArrayOffsets()));
1248     if (isa_ref_type(fetch.getType()))
1249       rewriter.replaceOp(fetch, coor);
1250     else
1251       rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor);
1252     return mlir::success();
1253   }
1254 
1255 private:
1256   const OperationUseMapT &useMap;
1257 };
1258 
1259 /// As array_access op is like an array_fetch op, except that it does not imply
1260 /// a load op. (It operates in the reference domain.)
1261 class ArrayAccessConversion : public ArrayUpdateConversionBase<ArrayAccessOp> {
1262 public:
1263   explicit ArrayAccessConversion(mlir::MLIRContext *ctx,
1264                                  const ArrayCopyAnalysis &a,
1265                                  const OperationUseMapT &m)
1266       : ArrayUpdateConversionBase{ctx, a, m} {}
1267 
1268   mlir::LogicalResult
1269   matchAndRewrite(ArrayAccessOp access,
1270                   mlir::PatternRewriter &rewriter) const override {
1271     auto *op = access.getOperation();
1272     auto loc = access.getLoc();
1273     if (analysis.inAmendAccessSet(op)) {
1274       // This array_access is associated with an array_amend and there is a
1275       // conflict. Make a copy to store into.
1276       auto result = referenceToClone(loc, rewriter, access);
1277       access.replaceAllUsesWith(result);
1278       rewriter.replaceOp(access, result);
1279       return mlir::success();
1280     }
1281     rewriter.setInsertionPoint(op);
1282     auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
1283     auto coor = genCoorOp(rewriter, loc, getEleTy(load.getType()),
1284                           toRefType(access.getType()), load.getMemref(),
1285                           load.getShape(), load.getSlice(), access.getIndices(),
1286                           load.getTypeparams(),
1287                           access->hasAttr(factory::attrFortranArrayOffsets()));
1288     rewriter.replaceOp(access, coor);
1289     return mlir::success();
1290   }
1291 };
1292 
1293 /// An array_amend op is a marker to record which array access is being used to
1294 /// update an array value. After this pass runs, an array_amend has no
1295 /// semantics. We rewrite these to undefined values here to remove them while
1296 /// preserving SSA form.
1297 class ArrayAmendConversion : public mlir::OpRewritePattern<ArrayAmendOp> {
1298 public:
1299   explicit ArrayAmendConversion(mlir::MLIRContext *ctx)
1300       : OpRewritePattern{ctx} {}
1301 
1302   mlir::LogicalResult
1303   matchAndRewrite(ArrayAmendOp amend,
1304                   mlir::PatternRewriter &rewriter) const override {
1305     auto *op = amend.getOperation();
1306     rewriter.setInsertionPoint(op);
1307     auto loc = amend.getLoc();
1308     auto undef = rewriter.create<UndefOp>(loc, amend.getType());
1309     rewriter.replaceOp(amend, undef.getResult());
1310     return mlir::success();
1311   }
1312 };
1313 
1314 class ArrayValueCopyConverter
1315     : public ArrayValueCopyBase<ArrayValueCopyConverter> {
1316 public:
1317   void runOnOperation() override {
1318     auto func = getOperation();
1319     LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '"
1320                             << func.getName() << "'\n");
1321     auto *context = &getContext();
1322 
1323     // Perform the conflict analysis.
1324     const auto &analysis = getAnalysis<ArrayCopyAnalysis>();
1325     const auto &useMap = analysis.getUseMap();
1326 
1327     mlir::RewritePatternSet patterns1(context);
1328     patterns1.insert<ArrayFetchConversion>(context, useMap);
1329     patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
1330     patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
1331     patterns1.insert<ArrayAccessConversion>(context, analysis, useMap);
1332     patterns1.insert<ArrayAmendConversion>(context);
1333     mlir::ConversionTarget target(*context);
1334     target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
1335                            mlir::arith::ArithmeticDialect,
1336                            mlir::func::FuncDialect>();
1337     target.addIllegalOp<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp,
1338                         ArrayUpdateOp, ArrayModifyOp>();
1339     // Rewrite the array fetch and array update ops.
1340     if (mlir::failed(
1341             mlir::applyPartialConversion(func, target, std::move(patterns1)))) {
1342       mlir::emitError(mlir::UnknownLoc::get(context),
1343                       "failure in array-value-copy pass, phase 1");
1344       signalPassFailure();
1345     }
1346 
1347     mlir::RewritePatternSet patterns2(context);
1348     patterns2.insert<ArrayLoadConversion>(context);
1349     patterns2.insert<ArrayMergeStoreConversion>(context);
1350     target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>();
1351     if (mlir::failed(
1352             mlir::applyPartialConversion(func, target, std::move(patterns2)))) {
1353       mlir::emitError(mlir::UnknownLoc::get(context),
1354                       "failure in array-value-copy pass, phase 2");
1355       signalPassFailure();
1356     }
1357   }
1358 };
1359 } // namespace
1360 
1361 std::unique_ptr<mlir::Pass> fir::createArrayValueCopyPass() {
1362   return std::make_unique<ArrayValueCopyConverter>();
1363 }
1364