1 //===-- ArrayValueCopy.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/Array.h"
10 #include "flang/Optimizer/Builder/BoxValue.h"
11 #include "flang/Optimizer/Builder/FIRBuilder.h"
12 #include "flang/Optimizer/Builder/Factory.h"
13 #include "flang/Optimizer/Builder/Runtime/Derived.h"
14 #include "flang/Optimizer/Builder/Todo.h"
15 #include "flang/Optimizer/Dialect/FIRDialect.h"
16 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
17 #include "flang/Optimizer/Support/FIRContext.h"
18 #include "flang/Optimizer/Transforms/Passes.h"
19 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "llvm/Support/Debug.h"
23 
24 namespace fir {
25 #define GEN_PASS_DEF_ARRAYVALUECOPY
26 #include "flang/Optimizer/Transforms/Passes.h.inc"
27 } // namespace fir
28 
29 #define DEBUG_TYPE "flang-array-value-copy"
30 
31 using namespace fir;
32 using namespace mlir;
33 
34 using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>;
35 
36 namespace {
37 
38 /// Array copy analysis.
39 /// Perform an interference analysis between array values.
40 ///
41 /// Lowering will generate a sequence of the following form.
42 /// ```mlir
43 ///   %a_1 = fir.array_load %array_1(%shape) : ...
44 ///   ...
45 ///   %a_j = fir.array_load %array_j(%shape) : ...
46 ///   ...
47 ///   %a_n = fir.array_load %array_n(%shape) : ...
48 ///     ...
49 ///     %v_i = fir.array_fetch %a_i, ...
50 ///     %a_j1 = fir.array_update %a_j, ...
51 ///     ...
52 ///   fir.array_merge_store %a_j, %a_jn to %array_j : ...
53 /// ```
54 ///
55 /// The analysis is to determine if there are any conflicts. A conflict is when
56 /// one of the following cases occurs.
57 ///
58 /// 1. There is an `array_update` to an array value, a_j, such that a_j was
59 /// loaded from the same array memory reference (array_j) but with a different
60 /// shape from the other array values a_i, where i != j. [Possible overlapping
61 /// arrays.]
62 ///
63 /// 2. There is either an array_fetch or array_update of a_j with a different
64 /// set of index values. [Possible loop-carried dependence.]
65 ///
66 /// If none of the array values overlap in storage and the accesses are not
67 /// loop-carried, then the arrays are conflict-free and no copies are required.
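///
/// For illustration only (a hand-written sketch in the same schematic style as
/// the sequence above, not actual lowering output), case 1 can arise when the
/// same memory reference is loaded twice and one of the values is updated:
/// ```mlir
///   %a = fir.array_load %arr(%shape_a) : ...
///   %b = fir.array_load %arr(%shape_b) : ...
///     ...
///     %v = fir.array_fetch %b, %i : ...
///     %a_1 = fir.array_update %a, %v, %i : ...
///     ...
///   fir.array_merge_store %a, %a_1 to %arr : ...
/// ```
/// Both %a and %b load from %arr, but %b is not the value being merged back
/// into %arr, so a potential overlap is reported and a copy will be made.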
68 class ArrayCopyAnalysis {
69 public:
70   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ArrayCopyAnalysis)
71 
72   using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>;
73   using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>;
74   using LoadMapSetsT = llvm::DenseMap<mlir::Operation *, UseSetT>;
75   using AmendAccessSetT = llvm::SmallPtrSet<mlir::Operation *, 4>;
76 
77   ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); }
78 
79   mlir::Operation *getOperation() const { return operation; }
80 
81   /// Return true iff the `array_merge_store` has potential conflicts.
82   bool hasPotentialConflict(mlir::Operation *op) const {
83     LLVM_DEBUG(llvm::dbgs()
84                << "looking for a conflict on " << *op
85                << " and the set has a total of " << conflicts.size() << '\n');
86     return conflicts.contains(op);
87   }
88 
89   /// Return the use map.
90   /// The use map maps array access, amend, fetch and update operations back to
91   /// the array load that is the original source of the array value.
92   /// It maps an array_load to an array_merge_store, if and only if the loaded
93   /// array value has pending modifications to be merged.
94   const OperationUseMapT &getUseMap() const { return useMap; }
95 
96   /// Return the set of array_access ops directly associated with array_amend
97   /// ops.
98   bool inAmendAccessSet(mlir::Operation *op) const {
99     return amendAccesses.count(op);
100   }
101 
102   /// For ArrayLoad `load`, return the transitive set of all OpOperands.
103   UseSetT getLoadUseSet(mlir::Operation *load) const {
104     assert(loadMapSets.count(load) && "analysis missed an array load?");
105     return loadMapSets.lookup(load);
106   }
107 
108   void arrayMentions(llvm::SmallVectorImpl<mlir::Operation *> &mentions,
109                      ArrayLoadOp load);
110 
111 private:
112   void construct(mlir::Operation *topLevelOp);
113 
114   mlir::Operation *operation; // operation that analysis ran upon
115   ConflictSetT conflicts;     // set of conflicts (loads and merge stores)
116   OperationUseMapT useMap;
117   LoadMapSetsT loadMapSets;
118   // Set of array_access ops associated with array_amend ops.
119   AmendAccessSetT amendAccesses;
120 };
121 } // namespace
122 
123 namespace {
124 /// Helper class to collect all array operations that produced an array value.
125 class ReachCollector {
126 public:
127   ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach,
128                  mlir::Region *loopRegion)
129       : reach{reach}, loopRegion{loopRegion} {}
130 
131   void collectArrayMentionFrom(mlir::Operation *op, mlir::ValueRange range) {
132     if (range.empty()) {
133       collectArrayMentionFrom(op, mlir::Value{});
134       return;
135     }
136     for (mlir::Value v : range)
137       collectArrayMentionFrom(v);
138   }
139 
140   // Collect all the array_access ops in `block`. This recursively looks into
141   // blocks in ops with regions.
142   // FIXME: This is temporarily relying on the array_amend appearing in a
143   // do_loop Region.  This phase ordering assumption can be eliminated by using
144   // dominance information to find the array_access ops or by scanning the
145   // transitive closure of the amending array_access's users and the defs that
146   // reach them.
147   void collectAccesses(llvm::SmallVector<ArrayAccessOp> &result,
148                        mlir::Block *block) {
149     for (auto &op : *block) {
150       if (auto access = mlir::dyn_cast<ArrayAccessOp>(op)) {
151         LLVM_DEBUG(llvm::dbgs() << "adding access: " << access << '\n');
152         result.push_back(access);
153         continue;
154       }
155       for (auto &region : op.getRegions())
156         for (auto &bb : region.getBlocks())
157           collectAccesses(result, &bb);
158     }
159   }
160 
161   void collectArrayMentionFrom(mlir::Operation *op, mlir::Value val) {
162     // `val` is defined by an Op, process the defining Op.
163     // If `val` is defined by a region containing Op, we want to drill down
164     // and through that Op's region(s).
165     LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n');
166     auto popFn = [&](auto rop) {
167       assert(val && "op must have a result value");
168       auto resNum = val.cast<mlir::OpResult>().getResultNumber();
169       llvm::SmallVector<mlir::Value> results;
170       rop.resultToSourceOps(results, resNum);
171       for (auto u : results)
172         collectArrayMentionFrom(u);
173     };
174     if (auto rop = mlir::dyn_cast<DoLoopOp>(op)) {
175       popFn(rop);
176       return;
177     }
178     if (auto rop = mlir::dyn_cast<IterWhileOp>(op)) {
179       popFn(rop);
180       return;
181     }
182     if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) {
183       popFn(rop);
184       return;
185     }
186     if (auto box = mlir::dyn_cast<EmboxOp>(op)) {
187       for (auto *user : box.getMemref().getUsers())
188         if (user != op)
189           collectArrayMentionFrom(user, user->getResults());
190       return;
191     }
192     if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
193       if (opIsInsideLoops(mergeStore))
194         collectArrayMentionFrom(mergeStore.getSequence());
195       return;
196     }
197 
198     if (mlir::isa<AllocaOp, AllocMemOp>(op)) {
199       // Look for any stores inside the loops, and collect an array operation
200       // that produced the value being stored to it.
201       for (auto *user : op->getUsers())
202         if (auto store = mlir::dyn_cast<fir::StoreOp>(user))
203           if (opIsInsideLoops(store))
204             collectArrayMentionFrom(store.getValue());
205       return;
206     }
207 
208     // Scan the uses of amend's memref
209     if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op)) {
210       reach.push_back(op);
211       llvm::SmallVector<ArrayAccessOp> accesses;
212       collectAccesses(accesses, op->getBlock());
213       for (auto access : accesses)
214         collectArrayMentionFrom(access.getResult());
215     }
216 
217     // Otherwise, Op does not contain a region so just chase its operands.
218     if (mlir::isa<ArrayAccessOp, ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp,
219                   ArrayFetchOp>(op)) {
220       LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
221       reach.push_back(op);
222     }
223 
224     // Include all array_access ops using an array_load.
225     if (auto arrLd = mlir::dyn_cast<ArrayLoadOp>(op))
226       for (auto *user : arrLd.getResult().getUsers())
227         if (mlir::isa<ArrayAccessOp>(user)) {
228           LLVM_DEBUG(llvm::dbgs() << "add " << *user << " to reachable set\n");
229           reach.push_back(user);
230         }
231 
232     // Array modify assignment is performed on the result, so the analysis must
233     // look at what is done with the result.
234     if (mlir::isa<ArrayModifyOp>(op))
235       for (auto *user : op->getResult(0).getUsers())
236         followUsers(user);
237 
238     if (mlir::isa<fir::CallOp>(op)) {
239       LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
240       reach.push_back(op);
241     }
242 
243     for (auto u : op->getOperands())
244       collectArrayMentionFrom(u);
245   }
246 
247   void collectArrayMentionFrom(mlir::BlockArgument ba) {
248     auto *parent = ba.getOwner()->getParentOp();
249     // If inside an Op holding a region, the block argument corresponds to an
250     // argument passed to the containing Op.
251     auto popFn = [&](auto rop) {
252       collectArrayMentionFrom(rop.blockArgToSourceOp(ba.getArgNumber()));
253     };
254     if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) {
255       popFn(rop);
256       return;
257     }
258     if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) {
259       popFn(rop);
260       return;
261     }
262     // Otherwise, a block argument is provided via the pred blocks.
263     for (auto *pred : ba.getOwner()->getPredecessors()) {
264       auto u = pred->getTerminator()->getOperand(ba.getArgNumber());
265       collectArrayMentionFrom(u);
266     }
267   }
268 
269   // Recursively trace operands to find all array operations relating to the
270   // values merged.
271   void collectArrayMentionFrom(mlir::Value val) {
272     if (!val || visited.contains(val))
273       return;
274     visited.insert(val);
275 
276     // Process a block argument.
277     if (auto ba = val.dyn_cast<mlir::BlockArgument>()) {
278       collectArrayMentionFrom(ba);
279       return;
280     }
281 
282     // Process an Op.
283     if (auto *op = val.getDefiningOp()) {
284       collectArrayMentionFrom(op, val);
285       return;
286     }
287 
288     emitFatalError(val.getLoc(), "unhandled value");
289   }
290 
291   /// Return all ops that produce the array value that is stored into the
292   /// `array_merge_store`.
293   static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach,
294                              mlir::Value seq) {
295     reach.clear();
296     mlir::Region *loopRegion = nullptr;
297     if (auto doLoop = mlir::dyn_cast_or_null<DoLoopOp>(seq.getDefiningOp()))
298       loopRegion = &doLoop->getRegion(0);
299     ReachCollector collector(reach, loopRegion);
300     collector.collectArrayMentionFrom(seq);
301   }
302 
303 private:
304   /// Is \p op inside the loop nest region?
305   /// FIXME: replace this structural dependence with graph properties.
306   bool opIsInsideLoops(mlir::Operation *op) const {
307     auto *region = op->getParentRegion();
308     while (region) {
309       if (region == loopRegion)
310         return true;
311       region = region->getParentRegion();
312     }
313     return false;
314   }
315 
316   /// Recursively trace the uses of an operation's results, calling
317   /// collectArrayMentionFrom on the direct and indirect user operands.
318   void followUsers(mlir::Operation *op) {
319     for (auto userOperand : op->getOperands())
320       collectArrayMentionFrom(userOperand);
321     // Go through potential converts/coordinate_op.
322     for (auto indirectUser : op->getUsers())
323       followUsers(indirectUser);
324   }
325 
326   llvm::SmallVectorImpl<mlir::Operation *> &reach;
327   llvm::SmallPtrSet<mlir::Value, 16> visited;
328   /// Region of the loop nest that produced the array value.
329   mlir::Region *loopRegion;
330 };
331 } // namespace
332 
333 /// Find all the array operations that access the array value that is loaded by
334 /// the array load operation, `load`.
335 void ArrayCopyAnalysis::arrayMentions(
336     llvm::SmallVectorImpl<mlir::Operation *> &mentions, ArrayLoadOp load) {
337   mentions.clear();
338   auto lmIter = loadMapSets.find(load);
339   if (lmIter != loadMapSets.end()) {
340     for (auto *opnd : lmIter->second) {
341       auto *owner = opnd->getOwner();
342       if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
343                     ArrayModifyOp>(owner))
344         mentions.push_back(owner);
345     }
346     return;
347   }
348 
349   UseSetT visited;
350   llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig]
351 
352   auto appendToQueue = [&](mlir::Value val) {
353     for (auto &use : val.getUses())
354       if (!visited.count(&use)) {
355         visited.insert(&use);
356         queue.push_back(&use);
357       }
358   };
359 
360   // Build the set of uses of `original`.
361   // let USES = { uses of original fir.array_load }
362   appendToQueue(load);
363 
364   // Process the worklist until done.
365   while (!queue.empty()) {
366     mlir::OpOperand *operand = queue.pop_back_val();
367     mlir::Operation *owner = operand->getOwner();
368     if (!owner)
369       continue;
370     auto structuredLoop = [&](auto ro) {
371       if (auto blockArg = ro.iterArgToBlockArg(operand->get())) {
372         int64_t arg = blockArg.getArgNumber();
373         mlir::Value output = ro.getResult(ro.getFinalValue() ? arg : arg - 1);
374         appendToQueue(output);
375         appendToQueue(blockArg);
376       }
377     };
378     // TODO: this needs to be updated to use the control-flow interface.
379     auto branchOp = [&](mlir::Block *dest, OperandRange operands) {
380       if (operands.empty())
381         return;
382 
383       // Check if this operand is within the range.
384       unsigned operandIndex = operand->getOperandNumber();
385       unsigned operandsStart = operands.getBeginOperandIndex();
386       if (operandIndex < operandsStart ||
387           operandIndex >= (operandsStart + operands.size()))
388         return;
389 
390       // Index the successor.
391       unsigned argIndex = operandIndex - operandsStart;
392       appendToQueue(dest->getArgument(argIndex));
393     };
394     // Thread uses into structured loop bodies and return value uses.
395     if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) {
396       structuredLoop(ro);
397     } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) {
398       structuredLoop(ro);
399     } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) {
400       // Thread any uses of fir.if that return the marked array value.
401       mlir::Operation *parent = rs->getParentRegion()->getParentOp();
402       if (auto ifOp = mlir::dyn_cast<fir::IfOp>(parent))
403         appendToQueue(ifOp.getResult(operand->getOperandNumber()));
404     } else if (mlir::isa<ArrayFetchOp>(owner)) {
405       // Keep track of array value fetches.
406       LLVM_DEBUG(llvm::dbgs()
407                  << "add fetch {" << *owner << "} to array value set\n");
408       mentions.push_back(owner);
409     } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) {
410       // Keep track of array value updates and thread the return value uses.
411       LLVM_DEBUG(llvm::dbgs()
412                  << "add update {" << *owner << "} to array value set\n");
413       mentions.push_back(owner);
414       appendToQueue(update.getResult());
415     } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) {
416       // Keep track of array value modification and thread the return value
417       // uses.
418       LLVM_DEBUG(llvm::dbgs()
419                  << "add modify {" << *owner << "} to array value set\n");
420       mentions.push_back(owner);
421       appendToQueue(update.getResult(1));
422     } else if (auto mention = mlir::dyn_cast<ArrayAccessOp>(owner)) {
423       mentions.push_back(owner);
424     } else if (auto amend = mlir::dyn_cast<ArrayAmendOp>(owner)) {
425       mentions.push_back(owner);
426       appendToQueue(amend.getResult());
427     } else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) {
428       branchOp(br.getDest(), br.getDestOperands());
429     } else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) {
430       branchOp(br.getTrueDest(), br.getTrueOperands());
431       branchOp(br.getFalseDest(), br.getFalseOperands());
432     } else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
433       // do nothing
434     } else {
435       llvm::report_fatal_error("array value reached unexpected op");
436     }
437   }
438   loadMapSets.insert({load, visited});
439 }
440 
441 static bool hasPointerType(mlir::Type type) {
442   if (auto boxTy = type.dyn_cast<BoxType>())
443     type = boxTy.getEleTy();
444   return type.isa<fir::PointerType>();
445 }
446 
447 // This is a NF performance hack. It makes a simple test that the slices of the
448 // load, \p ld, and the merge store, \p st, are trivially mutually exclusive.
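// For example (a schematic sketch, not verbatim IR), for an assignment along
// the lines of `a(j+1:n) = a(1:j)` the two fir.slice ops would carry triples
// of roughly this shape:
//   load slice:   %c1, %j, %c1
//   store slice:  %jp1, %n, %c1     with %jp1 = arith.addi %j, %c1
// The store lower bound is the load upper bound displaced by a positive
// constant, so the two ranges are treated as disjoint.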
449 static bool mutuallyExclusiveSliceRange(ArrayLoadOp ld, ArrayMergeStoreOp st) {
450   // If the same array_load, then no further testing is warranted.
451   if (ld.getResult() == st.getOriginal())
452     return false;
453 
454   auto getSliceOp = [](mlir::Value val) -> SliceOp {
455     if (!val)
456       return {};
457     auto sliceOp = mlir::dyn_cast_or_null<SliceOp>(val.getDefiningOp());
458     if (!sliceOp)
459       return {};
460     return sliceOp;
461   };
462 
463   auto ldSlice = getSliceOp(ld.getSlice());
464   auto stSlice = getSliceOp(st.getSlice());
465   if (!ldSlice || !stSlice)
466     return false;
467 
468   // Give up on subobject slices.
469   if (!ldSlice.getFields().empty() || !stSlice.getFields().empty() ||
470       !ldSlice.getSubstr().empty() || !stSlice.getSubstr().empty())
471     return false;
472 
473   // Crudely test that the two slices do not overlap by looking for the
474   // following general condition. If the slices look like (i:j) and (j+1:k) then
475   // these ranges do not overlap. The addend must be a constant.
476   auto ldTriples = ldSlice.getTriples();
477   auto stTriples = stSlice.getTriples();
478   const auto size = ldTriples.size();
479   if (size != stTriples.size())
480     return false;
481 
482   auto displacedByConstant = [](mlir::Value v1, mlir::Value v2) {
483     auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
484       auto *op = v.getDefiningOp();
485       while (auto conv = mlir::dyn_cast_or_null<ConvertOp>(op))
486         op = conv.getValue().getDefiningOp();
487       return op;
488     };
489 
490     auto isPositiveConstant = [](mlir::Value v) -> bool {
491       if (auto conOp =
492               mlir::dyn_cast<mlir::arith::ConstantOp>(v.getDefiningOp()))
493         if (auto iattr = conOp.getValue().dyn_cast<mlir::IntegerAttr>())
494           return iattr.getInt() > 0;
495       return false;
496     };
497 
498     auto *op1 = removeConvert(v1);
499     auto *op2 = removeConvert(v2);
500     if (!op1 || !op2)
501       return false;
502     if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
503       if ((addi.getLhs().getDefiningOp() == op1 &&
504            isPositiveConstant(addi.getRhs())) ||
505           (addi.getRhs().getDefiningOp() == op1 &&
506            isPositiveConstant(addi.getLhs())))
507         return true;
508     if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
509       if (subi.getLhs().getDefiningOp() == op2 &&
510           isPositiveConstant(subi.getRhs()))
511         return true;
512     return false;
513   };
514 
515   for (std::remove_const_t<decltype(size)> i = 0; i < size; i += 3) {
516     // If both are loop invariant, skip to the next triple.
517     if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i + 1].getDefiningOp()) &&
518         mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i + 1].getDefiningOp())) {
519       // Unless either is a vector index; in that case, be conservative.
520       if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i].getDefiningOp()) ||
521           mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i].getDefiningOp()))
522         return false;
523       continue;
524     }
525     // If identical, skip to the next triple.
526     if (ldTriples[i] == stTriples[i] && ldTriples[i + 1] == stTriples[i + 1] &&
527         ldTriples[i + 2] == stTriples[i + 2])
528       continue;
529     // If one slice's ubound and the other's lbound differ only by a positive
530     // constant, skip to the next triple.
531     if (displacedByConstant(ldTriples[i + 1], stTriples[i]) ||
532         displacedByConstant(stTriples[i + 1], ldTriples[i]))
533       continue;
534     return false;
535   }
536   LLVM_DEBUG(llvm::dbgs() << "detected non-overlapping slice ranges on " << ld
537                           << " and " << st << ", which is not a conflict\n");
538   return true;
539 }
540 
541 /// Is there a conflict between the array value that was updated and is to be
542 /// stored by `st`, and the set of arrays loaded (`reach`) that were used to
543 /// compute the updated value?
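/// For example (a schematic sketch; types elided), if the right-hand side is
/// loaded through a box of a POINTER, e.g.
/// ```mlir
///   %r = fir.array_load %rhs_box : ...
/// ```
/// where %rhs_box has a !fir.box<!fir.ptr<...>> type, the pointer may alias
/// the left-hand side storage, so the load is conservatively reported as a
/// conflict.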
544 static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach,
545                            ArrayMergeStoreOp st) {
546   mlir::Value load;
547   mlir::Value addr = st.getMemref();
548   const bool storeHasPointerType = hasPointerType(addr.getType());
549   for (auto *op : reach)
550     if (auto ld = mlir::dyn_cast<ArrayLoadOp>(op)) {
551       mlir::Type ldTy = ld.getMemref().getType();
552       if (ld.getMemref() == addr) {
553         if (mutuallyExclusiveSliceRange(ld, st))
554           continue;
555         if (ld.getResult() != st.getOriginal())
556           return true;
557         if (load) {
558           // TODO: extend this to allow checking if the first `load` and this
559           // `ld` are mutually exclusive accesses but not identical.
560           return true;
561         }
562         load = ld;
563       } else if ((hasPointerType(ldTy) || storeHasPointerType)) {
564         // TODO: Use target attribute to restrict this case further.
565         // TODO: Check if types can also allow ruling out some cases. For now,
566         // the fact that equivalences use the pointer attribute to enforce
567         // aliasing prevents any attempt to do so, and in general, it may
568         // be wrong to use this if any of the types is a complex or a derived
569         // type for which it is possible to create a pointer to a part with a
570         // different type than the whole, although this deserves some more
571         // investigation because existing compiler behavior seems to diverge
572         // here.
573         return true;
574       }
575     }
576   return false;
577 }
578 
579 /// Is there an access vector conflict on the array being merged into? If the
580 /// access vectors diverge, then assume that there are potentially overlapping
581 /// loop-carried references.
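/// For example (a schematic sketch), a fetch at (%i, %j) combined with an
/// update at (%i, %jp1) of the same loaded array value has diverging access
/// vectors:
/// ```mlir
///   %v   = fir.array_fetch %a, %i, %j : ...
///   %a_1 = fir.array_update %a, %v, %i, %jp1 : ...
/// ```
/// The second index differs, so a potential loop-carried dependence is
/// assumed and a conflict is reported.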
582 static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> mentions) {
583   if (mentions.size() < 2)
584     return false;
585   llvm::SmallVector<mlir::Value> indices;
586   LLVM_DEBUG(llvm::dbgs() << "check merge conflict with " << mentions.size()
587                           << " mentions on the list\n");
588   bool valSeen = false;
589   bool refSeen = false;
590   for (auto *op : mentions) {
591     llvm::SmallVector<mlir::Value> compareVector;
592     if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) {
593       valSeen = true;
594       if (indices.empty()) {
595         indices = u.getIndices();
596         continue;
597       }
598       compareVector = u.getIndices();
599     } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) {
600       valSeen = true;
601       if (indices.empty()) {
602         indices = f.getIndices();
603         continue;
604       }
605       compareVector = f.getIndices();
606     } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) {
607       valSeen = true;
608       if (indices.empty()) {
609         indices = f.getIndices();
610         continue;
611       }
612       compareVector = f.getIndices();
613     } else if (auto f = mlir::dyn_cast<ArrayAccessOp>(op)) {
614       refSeen = true;
615       if (indices.empty()) {
616         indices = f.getIndices();
617         continue;
618       }
619       compareVector = f.getIndices();
620     } else if (mlir::isa<ArrayAmendOp>(op)) {
621       refSeen = true;
622       continue;
623     } else {
624       mlir::emitError(op->getLoc(), "unexpected operation in analysis");
625     }
626     if (compareVector.size() != indices.size() ||
627         llvm::any_of(llvm::zip(compareVector, indices), [&](auto pair) {
628           return std::get<0>(pair) != std::get<1>(pair);
629         }))
630       return true;
631     LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n");
632   }
633   return valSeen && refSeen;
634 }
635 
636 /// With element-by-reference semantics, an amended array with more than one
637 /// access to the same loaded array is conservatively considered a conflict.
638 /// Note: the array copy can still be eliminated in subsequent optimizations.
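/// For example (a schematic sketch), two array_access ops on the same loaded
/// value, one of which feeds an array_amend, trigger the conservative copy:
/// ```mlir
///   %p   = fir.array_access %a, %i : ...
///   %q   = fir.array_access %a, %j : ...
///   %a_1 = fir.array_amend %a, %q : ...
/// ```
/// More than one by-reference access to %a cannot be proven safe here, so a
/// conflict is reported.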
639 static bool conflictOnReference(llvm::ArrayRef<mlir::Operation *> mentions) {
640   LLVM_DEBUG(llvm::dbgs() << "checking reference semantics " << mentions.size()
641                           << '\n');
642   if (mentions.size() < 3)
643     return false;
644   unsigned amendCount = 0;
645   unsigned accessCount = 0;
646   for (auto *op : mentions) {
647     if (mlir::isa<ArrayAmendOp>(op) && ++amendCount > 1) {
648       LLVM_DEBUG(llvm::dbgs() << "conflict: multiple amends of array value\n");
649       return true;
650     }
651     if (mlir::isa<ArrayAccessOp>(op) && ++accessCount > 1) {
652       LLVM_DEBUG(llvm::dbgs()
653                  << "conflict: multiple accesses of array value\n");
654       return true;
655     }
656     if (mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) {
657       LLVM_DEBUG(llvm::dbgs()
658                  << "conflict: array value has both uses by-value and uses "
659                     "by-reference. conservative assumption.\n");
660       return true;
661     }
662   }
663   return false;
664 }
665 
666 static mlir::Operation *
667 amendingAccess(llvm::ArrayRef<mlir::Operation *> mentions) {
668   for (auto *op : mentions)
669     if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op))
670       return amend.getMemref().getDefiningOp();
671   return {};
672 }
673 
674 // Are any conflicts present? The conflicts detected here are described above.
675 static bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
676                              llvm::ArrayRef<mlir::Operation *> mentions,
677                              ArrayMergeStoreOp st) {
678   return conflictOnLoad(reach, st) || conflictOnMerge(mentions);
679 }
680 
681 // Assume that any call to a function that uses host-associations will be
682 // modifying the output array.
683 static bool
684 conservativeCallConflict(llvm::ArrayRef<mlir::Operation *> reaches) {
685   return llvm::any_of(reaches, [](mlir::Operation *op) {
686     if (auto call = mlir::dyn_cast<fir::CallOp>(op))
687       if (auto callee =
688               call.getCallableForCallee().dyn_cast<mlir::SymbolRefAttr>()) {
689         auto module = op->getParentOfType<mlir::ModuleOp>();
690         return hasHostAssociationArgument(
691             module.lookupSymbol<mlir::func::FuncOp>(callee));
692       }
693     return false;
694   });
695 }
696 
697 /// Constructor of the array copy analysis.
698 /// This performs the analysis and saves the intermediate results.
699 void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) {
700   topLevelOp->walk([&](Operation *op) {
701     if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) {
702       llvm::SmallVector<mlir::Operation *> values;
703       ReachCollector::reachingValues(values, st.getSequence());
704       bool callConflict = conservativeCallConflict(values);
705       llvm::SmallVector<mlir::Operation *> mentions;
706       arrayMentions(mentions,
707                     mlir::cast<ArrayLoadOp>(st.getOriginal().getDefiningOp()));
708       bool conflict = conflictDetected(values, mentions, st);
709       bool refConflict = conflictOnReference(mentions);
710       if (callConflict || conflict || refConflict) {
711         LLVM_DEBUG(llvm::dbgs()
712                    << "CONFLICT: copies required for " << st << '\n'
713                    << "   adding conflicts on: " << *op << " and "
714                    << st.getOriginal() << '\n');
715         conflicts.insert(op);
716         conflicts.insert(st.getOriginal().getDefiningOp());
717         if (auto *access = amendingAccess(mentions))
718           amendAccesses.insert(access);
719       }
720       auto *ld = st.getOriginal().getDefiningOp();
721       LLVM_DEBUG(llvm::dbgs()
722                  << "map: adding {" << *ld << " -> " << st << "}\n");
723       useMap.insert({ld, op});
724     } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) {
725       llvm::SmallVector<mlir::Operation *> mentions;
726       arrayMentions(mentions, load);
727       LLVM_DEBUG(llvm::dbgs() << "process load: " << load
728                               << ", mentions: " << mentions.size() << '\n');
729       for (auto *acc : mentions) {
730         LLVM_DEBUG(llvm::dbgs() << " mention: " << *acc << '\n');
731         if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
732                       ArrayModifyOp>(acc)) {
733           if (useMap.count(acc)) {
734             mlir::emitError(
735                 load.getLoc(),
736                 "The parallel semantics of multiple array_merge_stores per "
737                 "array_load are not supported.");
738             continue;
739           }
740           LLVM_DEBUG(llvm::dbgs()
741                      << "map: adding {" << *acc << "} -> {" << load << "}\n");
742           useMap.insert({acc, op});
743         }
744       }
745     }
746   });
747 }
748 
749 //===----------------------------------------------------------------------===//
750 // Conversions for converting out of array value form.
751 //===----------------------------------------------------------------------===//
752 
753 namespace {
754 class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> {
755 public:
756   using OpRewritePattern::OpRewritePattern;
757 
758   mlir::LogicalResult
759   matchAndRewrite(ArrayLoadOp load,
760                   mlir::PatternRewriter &rewriter) const override {
761     LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n");
762     rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType());
763     return mlir::success();
764   }
765 };
766 
767 class ArrayMergeStoreConversion
768     : public mlir::OpRewritePattern<ArrayMergeStoreOp> {
769 public:
770   using OpRewritePattern::OpRewritePattern;
771 
772   mlir::LogicalResult
773   matchAndRewrite(ArrayMergeStoreOp store,
774                   mlir::PatternRewriter &rewriter) const override {
775     LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n");
776     rewriter.eraseOp(store);
777     return mlir::success();
778   }
779 };
780 } // namespace
781 
782 static mlir::Type getEleTy(mlir::Type ty) {
783   auto eleTy = unwrapSequenceType(unwrapPassByRefType(ty));
784   // FIXME: keep ptr/heap/ref information.
785   return ReferenceType::get(eleTy);
786 }
787 
788 // Extract extents from the ShapeOp/ShapeShiftOp into the result vector.
789 static bool getAdjustedExtents(mlir::Location loc,
790                                mlir::PatternRewriter &rewriter,
791                                ArrayLoadOp arrLoad,
792                                llvm::SmallVectorImpl<mlir::Value> &result,
793                                mlir::Value shape) {
794   bool copyUsingSlice = false;
795   auto *shapeOp = shape.getDefiningOp();
796   if (auto s = mlir::dyn_cast_or_null<ShapeOp>(shapeOp)) {
797     auto e = s.getExtents();
798     result.insert(result.end(), e.begin(), e.end());
799   } else if (auto s = mlir::dyn_cast_or_null<ShapeShiftOp>(shapeOp)) {
800     auto e = s.getExtents();
801     result.insert(result.end(), e.begin(), e.end());
802   } else {
803     emitFatalError(loc, "not a fir.shape/fir.shape_shift op");
804   }
805   auto idxTy = rewriter.getIndexType();
806   if (factory::isAssumedSize(result)) {
807     // Use slice information to compute the extent of the column.
808     auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
809     mlir::Value size = one;
810     if (mlir::Value sliceArg = arrLoad.getSlice()) {
811       if (auto sliceOp =
812               mlir::dyn_cast_or_null<SliceOp>(sliceArg.getDefiningOp())) {
813         auto triples = sliceOp.getTriples();
814         const std::size_t tripleSize = triples.size();
815         auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
816         fir::KindMapping kindMap = getKindMapping(module);
817         FirOpBuilder builder(rewriter, kindMap);
818         size = builder.genExtentFromTriplet(loc, triples[tripleSize - 3],
819                                             triples[tripleSize - 2],
820                                             triples[tripleSize - 1], idxTy);
821         copyUsingSlice = true;
822       }
823     }
824     result[result.size() - 1] = size;
825   }
826   return copyUsingSlice;
827 }
828 
829 /// Place the extents of the array load, \p arrLoad, into \p result and
830 /// return a ShapeOp or ShapeShiftOp with the same extents. If \p arrLoad is
831 /// loading a `!fir.box`, code will be generated to read the extents from the
832 /// boxed value, and the returned shape Op will be built with the extents read
833 /// from the box. Otherwise, the extents will be extracted from the ShapeOp (or
834 /// ShapeShiftOp) argument of \p arrLoad. \p copyUsingSlice will be set to true
835 /// if slicing of the output array is to be done in the copy-in/copy-out rather
836 /// than in the elemental computation step.
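/// For a boxed input, each extent is read from the box (schematic; types
/// elided):
/// ```mlir
///   %dims:3 = fir.box_dims %box, %dim : ...
///   // %dims#1 is the extent in dimension %dim
/// ```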
837 static mlir::Value getOrReadExtentsAndShapeOp(
838     mlir::Location loc, mlir::PatternRewriter &rewriter, ArrayLoadOp arrLoad,
839     llvm::SmallVectorImpl<mlir::Value> &result, bool &copyUsingSlice) {
840   assert(result.empty());
841   if (arrLoad->hasAttr(fir::getOptionalAttrName()))
842     fir::emitFatalError(
843         loc, "shapes from array load of OPTIONAL arrays must not be used");
844   if (auto boxTy = arrLoad.getMemref().getType().dyn_cast<BoxType>()) {
845     auto rank =
846         dyn_cast_ptrOrBoxEleTy(boxTy).cast<SequenceType>().getDimension();
847     auto idxTy = rewriter.getIndexType();
848     for (decltype(rank) dim = 0; dim < rank; ++dim) {
849       auto dimVal = rewriter.create<mlir::arith::ConstantIndexOp>(loc, dim);
850       auto dimInfo = rewriter.create<BoxDimsOp>(loc, idxTy, idxTy, idxTy,
851                                                 arrLoad.getMemref(), dimVal);
852       result.emplace_back(dimInfo.getResult(1));
853     }
854     if (!arrLoad.getShape()) {
855       auto shapeType = ShapeType::get(rewriter.getContext(), rank);
856       return rewriter.create<ShapeOp>(loc, shapeType, result);
857     }
858     auto shiftOp = arrLoad.getShape().getDefiningOp<ShiftOp>();
859     auto shapeShiftType = ShapeShiftType::get(rewriter.getContext(), rank);
860     llvm::SmallVector<mlir::Value> shapeShiftOperands;
861     for (auto [lb, extent] : llvm::zip(shiftOp.getOrigins(), result)) {
862       shapeShiftOperands.push_back(lb);
863       shapeShiftOperands.push_back(extent);
864     }
865     return rewriter.create<ShapeShiftOp>(loc, shapeShiftType,
866                                          shapeShiftOperands);
867   }
868   copyUsingSlice =
869       getAdjustedExtents(loc, rewriter, arrLoad, result, arrLoad.getShape());
870   return arrLoad.getShape();
871 }
872 
873 static mlir::Type toRefType(mlir::Type ty) {
874   if (fir::isa_ref_type(ty))
875     return ty;
876   return fir::ReferenceType::get(ty);
877 }
878 
879 static llvm::SmallVector<mlir::Value>
880 getTypeParamsIfRawData(mlir::Location loc, FirOpBuilder &builder,
881                        ArrayLoadOp arrLoad, mlir::Type ty) {
882   if (ty.isa<BoxType>())
883     return {};
884   return fir::factory::getTypeParams(loc, builder, arrLoad);
885 }
886 
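/// Generate the address of the array element designated by \p indices. The
/// leading indices address the element through a fir.array_coor; any trailing
/// indices (e.g. into a component of a derived-type element) are then applied
/// with a fir.coordinate_of on that address.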
887 static mlir::Value genCoorOp(mlir::PatternRewriter &rewriter,
888                              mlir::Location loc, mlir::Type eleTy,
889                              mlir::Type resTy, mlir::Value alloc,
890                              mlir::Value shape, mlir::Value slice,
891                              mlir::ValueRange indices, ArrayLoadOp load,
892                              bool skipOrig = false) {
893   llvm::SmallVector<mlir::Value> originated;
894   if (skipOrig)
895     originated.assign(indices.begin(), indices.end());
896   else
897     originated = factory::originateIndices(loc, rewriter, alloc.getType(),
898                                            shape, indices);
899   auto seqTy = dyn_cast_ptrOrBoxEleTy(alloc.getType());
900   assert(seqTy && seqTy.isa<SequenceType>());
901   const auto dimension = seqTy.cast<SequenceType>().getDimension();
902   auto module = load->getParentOfType<mlir::ModuleOp>();
903   fir::KindMapping kindMap = getKindMapping(module);
904   FirOpBuilder builder(rewriter, kindMap);
905   auto typeparams = getTypeParamsIfRawData(loc, builder, load, alloc.getType());
906   mlir::Value result = rewriter.create<ArrayCoorOp>(
907       loc, eleTy, alloc, shape, slice,
908       llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension),
909       typeparams);
910   if (dimension < originated.size())
911     result = rewriter.create<fir::CoordinateOp>(
912         loc, resTy, result,
913         llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension));
914   return result;
915 }
916 
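/// Compute the CHARACTER length of the elements loaded by \p load. For a
/// boxed, dynamic-length CHARACTER the length is derived from the element
/// size in bytes reported by the box; for instance (an illustrative case), a
/// CHARACTER(KIND=4) element of 40 bytes yields a length of 40 / (32 / 8) =
/// 10 characters.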
917 static mlir::Value getCharacterLen(mlir::Location loc, FirOpBuilder &builder,
918                                    ArrayLoadOp load, CharacterType charTy) {
919   auto charLenTy = builder.getCharacterLengthType();
920   if (charTy.hasDynamicLen()) {
921     if (load.getMemref().getType().isa<BoxType>()) {
922       // The loaded array is an emboxed value. Get the CHARACTER length from
923       // the box value.
924       auto eleSzInBytes =
925           builder.create<BoxEleSizeOp>(loc, charLenTy, load.getMemref());
926       auto kindSize =
927           builder.getKindMap().getCharacterBitsize(charTy.getFKind());
928       auto kindByteSize =
929           builder.createIntegerConstant(loc, charLenTy, kindSize / 8);
930       return builder.create<mlir::arith::DivSIOp>(loc, eleSzInBytes,
931                                                   kindByteSize);
932     }
933     // The loaded array is a (set of) unboxed values. If the CHARACTER's
934     // length is not a constant, it must be provided as a type parameter to
935     // the array_load.
936     auto typeparams = load.getTypeparams();
937     assert(typeparams.size() > 0 && "expected type parameters on array_load");
938     return typeparams.back();
939   }
940   // The typical case: the length of the CHARACTER is a compile-time
941   // constant that is encoded in the type information.
942   return builder.createIntegerConstant(loc, charLenTy, charTy.getLen());
943 }
944 /// Generate a shallow array copy. This is used for both copy-in and copy-out.
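/// A rough sketch (not verbatim output; types, constants, and index
/// origination elided) of the code generated for a rank-1 copy with a trivial
/// element type:
/// ```mlir
///   fir.do_loop %i = %c0 to %ub step %c1 {
///     %from = fir.array_coor %src(%shape) %i : ...
///     %to   = fir.array_coor %dst(%shape) %i : ...
///     %val  = fir.load %from : ...
///     fir.store %val to %to : ...
///   }
/// ```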
945 template <bool CopyIn>
946 void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter,
947                   mlir::Value dst, mlir::Value src, mlir::Value shapeOp,
948                   mlir::Value sliceOp, ArrayLoadOp arrLoad) {
949   auto insPt = rewriter.saveInsertionPoint();
950   llvm::SmallVector<mlir::Value> indices;
951   llvm::SmallVector<mlir::Value> extents;
952   bool copyUsingSlice =
953       getAdjustedExtents(loc, rewriter, arrLoad, extents, shapeOp);
954   auto idxTy = rewriter.getIndexType();
955   // Build loop nest from column to row.
956   for (auto sh : llvm::reverse(extents)) {
957     auto ubi = rewriter.create<ConvertOp>(loc, idxTy, sh);
958     auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
959     auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
960     auto ub = rewriter.create<mlir::arith::SubIOp>(loc, idxTy, ubi, one);
961     auto loop = rewriter.create<DoLoopOp>(loc, zero, ub, one);
962     rewriter.setInsertionPointToStart(loop.getBody());
963     indices.push_back(loop.getInductionVar());
964   }
965   // Reverse the indices so they are in column-major order.
966   std::reverse(indices.begin(), indices.end());
967   auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
968   fir::KindMapping kindMap = getKindMapping(module);
969   FirOpBuilder builder(rewriter, kindMap);
970   auto fromAddr = rewriter.create<ArrayCoorOp>(
971       loc, getEleTy(src.getType()), src, shapeOp,
972       CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
973       factory::originateIndices(loc, rewriter, src.getType(), shapeOp, indices),
974       getTypeParamsIfRawData(loc, builder, arrLoad, src.getType()));
975   auto toAddr = rewriter.create<ArrayCoorOp>(
976       loc, getEleTy(dst.getType()), dst, shapeOp,
977       !CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
978       factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, indices),
979       getTypeParamsIfRawData(loc, builder, arrLoad, dst.getType()));
980   auto eleTy = unwrapSequenceType(unwrapPassByRefType(dst.getType()));
981   // Copy from (to) object to (from) temp copy of same object.
982   if (auto charTy = eleTy.dyn_cast<CharacterType>()) {
983     auto len = getCharacterLen(loc, builder, arrLoad, charTy);
984     CharBoxValue toChar(toAddr, len);
985     CharBoxValue fromChar(fromAddr, len);
986     factory::genScalarAssignment(builder, loc, toChar, fromChar);
987   } else {
988     if (hasDynamicSize(eleTy))
989       TODO(loc, "copy element of dynamic size");
990     factory::genScalarAssignment(builder, loc, toAddr, fromAddr);
991   }
992   rewriter.restoreInsertionPoint(insPt);
993 }
994 
995 /// The array load may be either a boxed or unboxed value. If the value is
996 /// boxed, we read the type parameters from the boxed value.
997 static llvm::SmallVector<mlir::Value>
998 genArrayLoadTypeParameters(mlir::Location loc, mlir::PatternRewriter &rewriter,
999                            ArrayLoadOp load) {
1000   if (load.getTypeparams().empty()) {
1001     auto eleTy =
1002         unwrapSequenceType(unwrapPassByRefType(load.getMemref().getType()));
1003     if (hasDynamicSize(eleTy)) {
1004       if (auto charTy = eleTy.dyn_cast<CharacterType>()) {
1005         assert(load.getMemref().getType().isa<BoxType>());
1006         auto module = load->getParentOfType<mlir::ModuleOp>();
1007         fir::KindMapping kindMap = getKindMapping(module);
1008         FirOpBuilder builder(rewriter, kindMap);
1009         return {getCharacterLen(loc, builder, load, charTy)};
1010       }
1011       TODO(loc, "unhandled dynamic type parameters");
1012     }
1013     return {};
1014   }
1015   return load.getTypeparams();
1016 }
1017 
1018 static llvm::SmallVector<mlir::Value>
1019 findNonconstantExtents(mlir::Type memrefTy,
1020                        llvm::ArrayRef<mlir::Value> extents) {
1021   llvm::SmallVector<mlir::Value> nce;
1022   auto arrTy = unwrapPassByRefType(memrefTy);
1023   auto seqTy = arrTy.cast<SequenceType>();
1024   for (auto [s, x] : llvm::zip(seqTy.getShape(), extents))
1025     if (s == SequenceType::getUnknownExtent())
1026       nce.emplace_back(x);
1027   if (extents.size() > seqTy.getShape().size())
1028     for (auto x : extents.drop_front(seqTy.getShape().size()))
1029       nce.emplace_back(x);
1030   return nce;
1031 }
1032 
1033 /// Allocate temporary storage for an ArrayLoadOp \p load and initialize any
1034 /// allocatable direct components of the array elements with an unallocated
1035 /// status. Returns the temporary address as well as a callback to generate the
1036 /// temporary clean-up once it has been used. The clean-up will take care of
1037 /// deallocating all the element allocatable components that may have been
1038 /// allocated while using the temporary.
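/// For derived types with allocatable components, the temporary is wrapped in
/// a box so that the runtime can initialize it and later destroy it
/// (schematic; types elided):
/// ```mlir
///   %mem = fir.allocmem ...
///   %box = fir.embox %mem(%shape) : ...
///   // runtime call: initialize the derived-type components of %box
///   // at clean-up: runtime destroy of %box, then fir.freemem %mem
/// ```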
1039 static std::pair<mlir::Value,
1040                  std::function<void(mlir::PatternRewriter &rewriter)>>
1041 allocateArrayTemp(mlir::Location loc, mlir::PatternRewriter &rewriter,
1042                   ArrayLoadOp load, llvm::ArrayRef<mlir::Value> extents,
1043                   mlir::Value shape) {
1044   mlir::Type baseType = load.getMemref().getType();
1045   llvm::SmallVector<mlir::Value> nonconstantExtents =
1046       findNonconstantExtents(baseType, extents);
1047   llvm::SmallVector<mlir::Value> typeParams =
1048       genArrayLoadTypeParameters(loc, rewriter, load);
1049   mlir::Value allocmem = rewriter.create<AllocMemOp>(
1050       loc, dyn_cast_ptrOrBoxEleTy(baseType), typeParams, nonconstantExtents);
1051   mlir::Type eleType =
1052       fir::unwrapSequenceType(fir::unwrapPassByRefType(baseType));
1053   if (fir::isRecordWithAllocatableMember(eleType)) {
1054     // The allocatable component descriptors need to be set to a clean
1055     // deallocated status before anything is done with them.
1056     mlir::Value box = rewriter.create<fir::EmboxOp>(
1057         loc, fir::BoxType::get(baseType), allocmem, shape,
1058         /*slice=*/mlir::Value{}, typeParams);
1059     auto module = load->getParentOfType<mlir::ModuleOp>();
1060     fir::KindMapping kindMap = getKindMapping(module);
1061     FirOpBuilder builder(rewriter, kindMap);
1062     runtime::genDerivedTypeInitialize(builder, loc, box);
1063     // Any allocatable component that may have been allocated must be
1064     // deallocated during the clean-up.
1065     auto cleanup = [=](mlir::PatternRewriter &r) {
1066       fir::KindMapping kindMap = getKindMapping(module);
1067       FirOpBuilder builder(r, kindMap);
1068       runtime::genDerivedTypeDestroy(builder, loc, box);
1069       r.create<FreeMemOp>(loc, allocmem);
1070     };
1071     return {allocmem, cleanup};
1072   }
1073   auto cleanup = [=](mlir::PatternRewriter &r) {
1074     r.create<FreeMemOp>(loc, allocmem);
1075   };
1076   return {allocmem, cleanup};
1077 }
1078 
1079 namespace {
1080 /// Conversion of fir.array_update and fir.array_modify Ops.
1081 /// If there is a conflict for the update, then we need to perform a
1082 /// copy-in/copy-out to preserve the original values of the array. If there is
1083 /// no conflict, then it is safe to eschew making any copies.
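/// When a copy is required, the rewrite conceptually produces (a schematic
/// sketch; types elided):
/// ```mlir
///   %tmp = fir.allocmem ...              // temporary with the same shape
///   // copy loops: original array -> %tmp
///   %addr = fir.array_coor %tmp(%shape) %i : ...
///   fir.store %val to %addr : ...        // the element update, done in %tmp
///   // copy loops: %tmp -> original array
///   fir.freemem %tmp : ...
/// ```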
1084 template <typename ArrayOp>
1085 class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> {
1086 public:
1087   // TODO: Implement copy/swap semantics?
1088   explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx,
1089                                      const ArrayCopyAnalysis &a,
1090                                      const OperationUseMapT &m)
1091       : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {}
1092 
1093   /// The array_access, \p access, is to be made to a cloned copy due to a potential
1094   /// conflict. Uses copy-in/copy-out semantics and not copy/swap.
1095   mlir::Value referenceToClone(mlir::Location loc,
1096                                mlir::PatternRewriter &rewriter,
1097                                ArrayOp access) const {
1098     LLVM_DEBUG(llvm::dbgs()
1099                << "generating copy-in/copy-out loops for " << access << '\n');
1100     auto *op = access.getOperation();
1101     auto *loadOp = useMap.lookup(op);
1102     auto load = mlir::cast<ArrayLoadOp>(loadOp);
1103     auto eleTy = access.getType();
1104     rewriter.setInsertionPoint(loadOp);
1105     // Copy in.
1106     llvm::SmallVector<mlir::Value> extents;
1107     bool copyUsingSlice = false;
1108     auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
1109                                               copyUsingSlice);
1110     auto [allocmem, genTempCleanUp] =
1111         allocateArrayTemp(loc, rewriter, load, extents, shapeOp);
1112     genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
1113                                   load.getMemref(), shapeOp, load.getSlice(),
1114                                   load);
1115     // Generate the reference for the access.
1116     rewriter.setInsertionPoint(op);
1117     auto coor = genCoorOp(
1118         rewriter, loc, getEleTy(load.getType()), eleTy, allocmem, shapeOp,
1119         copyUsingSlice ? mlir::Value{} : load.getSlice(), access.getIndices(),
1120         load, access->hasAttr(factory::attrFortranArrayOffsets()));
1121     // Copy out.
1122     auto *storeOp = useMap.lookup(loadOp);
1123     auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
1124     rewriter.setInsertionPoint(storeOp);
1125     // Copy out.
1126     genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter, store.getMemref(),
1127                                    allocmem, shapeOp, store.getSlice(), load);
1128     genTempCleanUp(rewriter);
1129     return coor;
1130   }
1131 
1132   /// Copy the RHS element into the LHS and insert copy-in/copy-out between a
1133   /// temp and the LHS if the analysis found potential overlaps between the RHS
1134   /// and LHS arrays. The element copy generator must be provided in \p
1135   /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp.
1136   /// Returns the address of the LHS element inside the loop and the LHS
1137   /// ArrayLoad result.
1138   std::pair<mlir::Value, mlir::Value>
1139   materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter,
1140                         ArrayOp update,
1141                         const std::function<void(mlir::Value)> &assignElement,
1142                         mlir::Type lhsEltRefType) const {
1143     auto *op = update.getOperation();
1144     auto *loadOp = useMap.lookup(op);
1145     auto load = mlir::cast<ArrayLoadOp>(loadOp);
1146     LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n");
1147     if (analysis.hasPotentialConflict(loadOp)) {
1148       // If there is a conflict between the arrays, then we copy the lhs array
1149       // to a temporary, update the temporary, and copy the temporary back to
1150       // the lhs array. This yields Fortran's copy-in copy-out array semantics.
1151       LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n");
1152       rewriter.setInsertionPoint(loadOp);
1153       // Copy in.
1154       llvm::SmallVector<mlir::Value> extents;
1155       bool copyUsingSlice = false;
1156       auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
1157                                                 copyUsingSlice);
1158       auto [allocmem, genTempCleanUp] =
1159           allocateArrayTemp(loc, rewriter, load, extents, shapeOp);
1160 
1161       genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
1162                                     load.getMemref(), shapeOp, load.getSlice(),
1163                                     load);
1164       rewriter.setInsertionPoint(op);
1165       auto coor = genCoorOp(
1166           rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem,
1167           shapeOp, copyUsingSlice ? mlir::Value{} : load.getSlice(),
1168           update.getIndices(), load,
1169           update->hasAttr(factory::attrFortranArrayOffsets()));
1170       assignElement(coor);
1171       auto *storeOp = useMap.lookup(loadOp);
1172       auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
1173       rewriter.setInsertionPoint(storeOp);
1174       // Copy out.
1175       genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter,
1176                                      store.getMemref(), allocmem, shapeOp,
1177                                      store.getSlice(), load);
1178       genTempCleanUp(rewriter);
1179       return {coor, load.getResult()};
1180     }
1181     // Otherwise, when there is no conflict (no possible loop-carried
1182     // dependence), the lhs array can be updated in place.
1183     LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n");
1184     rewriter.setInsertionPoint(op);
1185     auto coorTy = getEleTy(load.getType());
1186     auto coor =
1187         genCoorOp(rewriter, loc, coorTy, lhsEltRefType, load.getMemref(),
1188                   load.getShape(), load.getSlice(), update.getIndices(), load,
1189                   update->hasAttr(factory::attrFortranArrayOffsets()));
1190     assignElement(coor);
1191     return {coor, load.getResult()};
1192   }
1193 
1194 protected:
1195   const ArrayCopyAnalysis &analysis;
1196   const OperationUseMapT &useMap;
1197 };
1198 
1199 class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> {
1200 public:
1201   explicit ArrayUpdateConversion(mlir::MLIRContext *ctx,
1202                                  const ArrayCopyAnalysis &a,
1203                                  const OperationUseMapT &m)
1204       : ArrayUpdateConversionBase{ctx, a, m} {}
1205 
1206   mlir::LogicalResult
1207   matchAndRewrite(ArrayUpdateOp update,
1208                   mlir::PatternRewriter &rewriter) const override {
1209     auto loc = update.getLoc();
1210     auto assignElement = [&](mlir::Value coor) {
1211       auto input = update.getMerge();
1212       if (auto inEleTy = dyn_cast_ptrEleTy(input.getType())) {
1213         emitFatalError(loc, "array_update on references not supported");
1214       } else {
1215         rewriter.create<fir::StoreOp>(loc, input, coor);
1216       }
1217     };
1218     auto lhsEltRefType = toRefType(update.getMerge().getType());
1219     auto [_, lhsLoadResult] = materializeAssignment(
1220         loc, rewriter, update, assignElement, lhsEltRefType);
1221     update.replaceAllUsesWith(lhsLoadResult);
1222     rewriter.replaceOp(update, lhsLoadResult);
1223     return mlir::success();
1224   }
1225 };
1226 
1227 class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> {
1228 public:
1229   explicit ArrayModifyConversion(mlir::MLIRContext *ctx,
1230                                  const ArrayCopyAnalysis &a,
1231                                  const OperationUseMapT &m)
1232       : ArrayUpdateConversionBase{ctx, a, m} {}
1233 
1234   mlir::LogicalResult
1235   matchAndRewrite(ArrayModifyOp modify,
1236                   mlir::PatternRewriter &rewriter) const override {
1237     auto loc = modify.getLoc();
1238     auto assignElement = [](mlir::Value) {
1239       // Assignment already materialized by lowering using lhs element address.
1240     };
1241     auto lhsEltRefType = modify.getResult(0).getType();
1242     auto [lhsEltCoor, lhsLoadResult] = materializeAssignment(
1243         loc, rewriter, modify, assignElement, lhsEltRefType);
1244     modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult});
1245     rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult});
1246     return mlir::success();
1247   }
1248 };
1249 
1250 class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> {
1251 public:
1252   explicit ArrayFetchConversion(mlir::MLIRContext *ctx,
1253                                 const OperationUseMapT &m)
1254       : OpRewritePattern{ctx}, useMap{m} {}
1255 
1256   mlir::LogicalResult
1257   matchAndRewrite(ArrayFetchOp fetch,
1258                   mlir::PatternRewriter &rewriter) const override {
1259     auto *op = fetch.getOperation();
1260     rewriter.setInsertionPoint(op);
1261     auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
1262     auto loc = fetch.getLoc();
1263     auto coor = genCoorOp(
1264         rewriter, loc, getEleTy(load.getType()), toRefType(fetch.getType()),
1265         load.getMemref(), load.getShape(), load.getSlice(), fetch.getIndices(),
1266         load, fetch->hasAttr(factory::attrFortranArrayOffsets()));
1267     if (isa_ref_type(fetch.getType()))
1268       rewriter.replaceOp(fetch, coor);
1269     else
1270       rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor);
1271     return mlir::success();
1272   }
1273 
1274 private:
1275   const OperationUseMapT &useMap;
1276 };
1277 
1278 /// An array_access op is like an array_fetch op, except that it does not imply
1279 /// a load op. (It operates in the reference domain.)
1280 class ArrayAccessConversion : public ArrayUpdateConversionBase<ArrayAccessOp> {
1281 public:
1282   explicit ArrayAccessConversion(mlir::MLIRContext *ctx,
1283                                  const ArrayCopyAnalysis &a,
1284                                  const OperationUseMapT &m)
1285       : ArrayUpdateConversionBase{ctx, a, m} {}
1286 
1287   mlir::LogicalResult
1288   matchAndRewrite(ArrayAccessOp access,
1289                   mlir::PatternRewriter &rewriter) const override {
1290     auto *op = access.getOperation();
1291     auto loc = access.getLoc();
1292     if (analysis.inAmendAccessSet(op)) {
1293       // This array_access is associated with an array_amend and there is a
1294       // conflict. Make a copy to store into.
1295       auto result = referenceToClone(loc, rewriter, access);
1296       access.replaceAllUsesWith(result);
1297       rewriter.replaceOp(access, result);
1298       return mlir::success();
1299     }
1300     rewriter.setInsertionPoint(op);
1301     auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
1302     auto coor = genCoorOp(
1303         rewriter, loc, getEleTy(load.getType()), toRefType(access.getType()),
1304         load.getMemref(), load.getShape(), load.getSlice(), access.getIndices(),
1305         load, access->hasAttr(factory::attrFortranArrayOffsets()));
1306     rewriter.replaceOp(access, coor);
1307     return mlir::success();
1308   }
1309 };
1310 
1311 /// An array_amend op is a marker to record which array access is being used to
1312 /// update an array value. After this pass runs, an array_amend has no
1313 /// semantics. We rewrite these to undefined values here to remove them while
1314 /// preserving SSA form.
1315 class ArrayAmendConversion : public mlir::OpRewritePattern<ArrayAmendOp> {
1316 public:
1317   explicit ArrayAmendConversion(mlir::MLIRContext *ctx)
1318       : OpRewritePattern{ctx} {}
1319 
1320   mlir::LogicalResult
1321   matchAndRewrite(ArrayAmendOp amend,
1322                   mlir::PatternRewriter &rewriter) const override {
1323     auto *op = amend.getOperation();
1324     rewriter.setInsertionPoint(op);
1325     auto loc = amend.getLoc();
1326     auto undef = rewriter.create<UndefOp>(loc, amend.getType());
1327     rewriter.replaceOp(amend, undef.getResult());
1328     return mlir::success();
1329   }
1330 };
1331 
1332 class ArrayValueCopyConverter
1333     : public fir::impl::ArrayValueCopyBase<ArrayValueCopyConverter> {
1334 public:
1335   void runOnOperation() override {
1336     auto func = getOperation();
1337     LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '"
1338                             << func.getName() << "'\n");
1339     auto *context = &getContext();
1340 
1341     // Perform the conflict analysis.
1342     const auto &analysis = getAnalysis<ArrayCopyAnalysis>();
1343     const auto &useMap = analysis.getUseMap();
1344 
1345     mlir::RewritePatternSet patterns1(context);
1346     patterns1.insert<ArrayFetchConversion>(context, useMap);
1347     patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
1348     patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
1349     patterns1.insert<ArrayAccessConversion>(context, analysis, useMap);
1350     patterns1.insert<ArrayAmendConversion>(context);
1351     mlir::ConversionTarget target(*context);
1352     target
1353         .addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
1354                          mlir::arith::ArithDialect, mlir::func::FuncDialect>();
1355     target.addIllegalOp<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp,
1356                         ArrayUpdateOp, ArrayModifyOp>();
1357     // Rewrite the array fetch and array update ops.
1358     if (mlir::failed(
1359             mlir::applyPartialConversion(func, target, std::move(patterns1)))) {
1360       mlir::emitError(mlir::UnknownLoc::get(context),
1361                       "failure in array-value-copy pass, phase 1");
1362       signalPassFailure();
1363     }
1364 
1365     mlir::RewritePatternSet patterns2(context);
1366     patterns2.insert<ArrayLoadConversion>(context);
1367     patterns2.insert<ArrayMergeStoreConversion>(context);
1368     target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>();
1369     if (mlir::failed(
1370             mlir::applyPartialConversion(func, target, std::move(patterns2)))) {
1371       mlir::emitError(mlir::UnknownLoc::get(context),
1372                       "failure in array-value-copy pass, phase 2");
1373       signalPassFailure();
1374     }
1375   }
1376 };
1377 } // namespace
1378 
1379 std::unique_ptr<mlir::Pass> fir::createArrayValueCopyPass() {
1380   return std::make_unique<ArrayValueCopyConverter>();
1381 }
1382