xref: /llvm-project/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp (revision 8b68da2c7d97ef0e2dde4d9eb867020bde2f5fe2)
1 //===-- ArrayValueCopy.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Builder/Array.h"
11 #include "flang/Optimizer/Builder/BoxValue.h"
12 #include "flang/Optimizer/Builder/FIRBuilder.h"
13 #include "flang/Optimizer/Builder/Factory.h"
14 #include "flang/Optimizer/Builder/Runtime/Derived.h"
15 #include "flang/Optimizer/Builder/Todo.h"
16 #include "flang/Optimizer/Dialect/FIRDialect.h"
17 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
18 #include "flang/Optimizer/Support/FIRContext.h"
19 #include "flang/Optimizer/Transforms/Passes.h"
20 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
21 #include "mlir/Dialect/SCF/IR/SCF.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/Support/Debug.h"
24 
25 #define DEBUG_TYPE "flang-array-value-copy"
26 
27 using namespace fir;
28 using namespace mlir;
29 
30 using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>;
31 
32 namespace {
33 
34 /// Array copy analysis.
35 /// Perform an interference analysis between array values.
36 ///
37 /// Lowering will generate a sequence of the following form.
38 /// ```mlir
39 ///   %a_1 = fir.array_load %array_1(%shape) : ...
40 ///   ...
41 ///   %a_j = fir.array_load %array_j(%shape) : ...
42 ///   ...
43 ///   %a_n = fir.array_load %array_n(%shape) : ...
44 ///     ...
45 ///     %v_i = fir.array_fetch %a_i, ...
46 ///     %a_j1 = fir.array_update %a_j, ...
47 ///     ...
48 ///   fir.array_merge_store %a_j, %a_jn to %array_j : ...
49 /// ```
50 ///
51 /// The analysis is to determine if there are any conflicts. A conflict is when
52 /// one of the following cases occurs.
53 ///
54 /// 1. There is an `array_update` to an array value, a_j, such that a_j was
55 /// loaded from the same array memory reference (array_j) but with a different
56 /// shape than the other array values a_i, where i != j. [Possible overlapping
57 /// arrays.]
58 ///
59 /// 2. There is either an array_fetch or array_update of a_j with a different
60 /// set of index values. [Possible loop-carried dependence.]
61 ///
62 /// If none of the array values overlap in storage and the accesses are not
63 /// loop-carried, then the arrays are conflict-free and no copies are required.
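///
/// For illustration, an assignment between overlapping sections of the same
/// array, e.g. `a(2:10) = a(1:9)`, typically lowers to two array_loads of `a`
/// with different slices feeding one array_merge_store; the analysis flags
/// this as a potential overlap and a temporary copy is introduced (Fortran
/// copy-in/copy-out semantics). Accesses that use identical shapes and index
/// vectors raise no conflict and need no copy.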
64 class ArrayCopyAnalysis {
65 public:
66   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ArrayCopyAnalysis)
67 
68   using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>;
69   using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>;
70   using LoadMapSetsT = llvm::DenseMap<mlir::Operation *, UseSetT>;
71   using AmendAccessSetT = llvm::SmallPtrSet<mlir::Operation *, 4>;
72 
73   ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); }
74 
75   mlir::Operation *getOperation() const { return operation; }
76 
77   /// Return true iff the `array_merge_store` has potential conflicts.
78   bool hasPotentialConflict(mlir::Operation *op) const {
79     LLVM_DEBUG(llvm::dbgs()
80                << "looking for a conflict on " << *op
81                << " and the set has a total of " << conflicts.size() << '\n');
82     return conflicts.contains(op);
83   }
84 
85   /// Return the use map.
86   /// The use map maps array access, amend, fetch and update operations back to
87   /// the array load that is the original source of the array value.
88   /// It maps an array_load to an array_merge_store, if and only if the loaded
89   /// array value has pending modifications to be merged.
90   const OperationUseMapT &getUseMap() const { return useMap; }
91 
92   /// Return true iff \p op is an array_access op directly associated with an
93   /// array_amend op.
94   bool inAmendAccessSet(mlir::Operation *op) const {
95     return amendAccesses.count(op);
96   }
97 
98   /// For ArrayLoad `load`, return the transitive set of all OpOperands.
99   UseSetT getLoadUseSet(mlir::Operation *load) const {
100     assert(loadMapSets.count(load) && "analysis missed an array load?");
101     return loadMapSets.lookup(load);
102   }
103 
104   void arrayMentions(llvm::SmallVectorImpl<mlir::Operation *> &mentions,
105                      ArrayLoadOp load);
106 
107 private:
108   void construct(mlir::Operation *topLevelOp);
109 
110   mlir::Operation *operation; // operation that analysis ran upon
111   ConflictSetT conflicts;     // set of conflicts (loads and merge stores)
112   OperationUseMapT useMap;
113   LoadMapSetsT loadMapSets;
114   // Set of array_access ops associated with array_amend ops.
115   AmendAccessSetT amendAccesses;
116 };
117 } // namespace
118 
119 namespace {
120 /// Helper class to collect all array operations that produced an array value.
121 class ReachCollector {
122 public:
123   ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach,
124                  mlir::Region *loopRegion)
125       : reach{reach}, loopRegion{loopRegion} {}
126 
127   void collectArrayMentionFrom(mlir::Operation *op, mlir::ValueRange range) {
128     if (range.empty()) {
129       collectArrayMentionFrom(op, mlir::Value{});
130       return;
131     }
132     for (mlir::Value v : range)
133       collectArrayMentionFrom(v);
134   }
135 
136   // Collect all the array_access ops in `block`. This recursively looks into
137   // blocks in ops with regions.
138   // FIXME: This is temporarily relying on the array_amend appearing in a
139   // do_loop Region.  This phase ordering assumption can be eliminated by using
140   // dominance information to find the array_access ops or by scanning the
141   // transitive closure of the amending array_access's users and the defs that
142   // reach them.
143   void collectAccesses(llvm::SmallVector<ArrayAccessOp> &result,
144                        mlir::Block *block) {
145     for (auto &op : *block) {
146       if (auto access = mlir::dyn_cast<ArrayAccessOp>(op)) {
147         LLVM_DEBUG(llvm::dbgs() << "adding access: " << access << '\n');
148         result.push_back(access);
149         continue;
150       }
151       for (auto &region : op.getRegions())
152         for (auto &bb : region.getBlocks())
153           collectAccesses(result, &bb);
154     }
155   }
156 
157   void collectArrayMentionFrom(mlir::Operation *op, mlir::Value val) {
158     // `val` is defined by an Op, so process the defining Op.
159     // If `val` is defined by an Op containing regions, we want to drill down
160     // into and through that Op's region(s).
161     LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n');
162     auto popFn = [&](auto rop) {
163       assert(val && "op must have a result value");
164       auto resNum = val.cast<mlir::OpResult>().getResultNumber();
165       llvm::SmallVector<mlir::Value> results;
166       rop.resultToSourceOps(results, resNum);
167       for (auto u : results)
168         collectArrayMentionFrom(u);
169     };
170     if (auto rop = mlir::dyn_cast<DoLoopOp>(op)) {
171       popFn(rop);
172       return;
173     }
174     if (auto rop = mlir::dyn_cast<IterWhileOp>(op)) {
175       popFn(rop);
176       return;
177     }
178     if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) {
179       popFn(rop);
180       return;
181     }
182     if (auto box = mlir::dyn_cast<EmboxOp>(op)) {
183       for (auto *user : box.getMemref().getUsers())
184         if (user != op)
185           collectArrayMentionFrom(user, user->getResults());
186       return;
187     }
188     if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
189       if (opIsInsideLoops(mergeStore))
190         collectArrayMentionFrom(mergeStore.getSequence());
191       return;
192     }
193 
194     if (mlir::isa<AllocaOp, AllocMemOp>(op)) {
195       // Look for any stores inside the loops, and collect an array operation
196       // that produced the value being stored to it.
197       for (auto *user : op->getUsers())
198         if (auto store = mlir::dyn_cast<fir::StoreOp>(user))
199           if (opIsInsideLoops(store))
200             collectArrayMentionFrom(store.getValue());
201       return;
202     }
203 
204     // Scan the uses of amend's memref
205     if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op)) {
206       reach.push_back(op);
207       llvm::SmallVector<ArrayAccessOp> accesses;
208       collectAccesses(accesses, op->getBlock());
209       for (auto access : accesses)
210         collectArrayMentionFrom(access.getResult());
211     }
212 
213     // Otherwise, Op does not contain a region so just chase its operands.
214     if (mlir::isa<ArrayAccessOp, ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp,
215                   ArrayFetchOp>(op)) {
216       LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
217       reach.push_back(op);
218     }
219 
220     // Include all array_access ops using an array_load.
221     if (auto arrLd = mlir::dyn_cast<ArrayLoadOp>(op))
222       for (auto *user : arrLd.getResult().getUsers())
223         if (mlir::isa<ArrayAccessOp>(user)) {
224           LLVM_DEBUG(llvm::dbgs() << "add " << *user << " to reachable set\n");
225           reach.push_back(user);
226         }
227 
228     // Array modify assignment is performed on the result, so the analysis must
229     // look at what is done with the result.
230     if (mlir::isa<ArrayModifyOp>(op))
231       for (auto *user : op->getResult(0).getUsers())
232         followUsers(user);
233 
234     if (mlir::isa<fir::CallOp>(op)) {
235       LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
236       reach.push_back(op);
237     }
238 
239     for (auto u : op->getOperands())
240       collectArrayMentionFrom(u);
241   }
242 
243   void collectArrayMentionFrom(mlir::BlockArgument ba) {
244     auto *parent = ba.getOwner()->getParentOp();
245     // If inside an Op holding a region, the block argument corresponds to an
246     // argument passed to the containing Op.
247     auto popFn = [&](auto rop) {
248       collectArrayMentionFrom(rop.blockArgToSourceOp(ba.getArgNumber()));
249     };
250     if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) {
251       popFn(rop);
252       return;
253     }
254     if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) {
255       popFn(rop);
256       return;
257     }
258     // Otherwise, a block argument is provided via the pred blocks.
259     for (auto *pred : ba.getOwner()->getPredecessors()) {
260       auto u = pred->getTerminator()->getOperand(ba.getArgNumber());
261       collectArrayMentionFrom(u);
262     }
263   }
264 
265   // Recursively trace operands to find all array operations relating to the
266   // values merged.
267   void collectArrayMentionFrom(mlir::Value val) {
268     if (!val || visited.contains(val))
269       return;
270     visited.insert(val);
271 
272     // Process a block argument.
273     if (auto ba = val.dyn_cast<mlir::BlockArgument>()) {
274       collectArrayMentionFrom(ba);
275       return;
276     }
277 
278     // Process an Op.
279     if (auto *op = val.getDefiningOp()) {
280       collectArrayMentionFrom(op, val);
281       return;
282     }
283 
284     emitFatalError(val.getLoc(), "unhandled value");
285   }
286 
287   /// Return all ops that produce the array value that is stored into the
288   /// `array_merge_store`.
289   static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach,
290                              mlir::Value seq) {
291     reach.clear();
292     mlir::Region *loopRegion = nullptr;
293     if (auto doLoop = mlir::dyn_cast_or_null<DoLoopOp>(seq.getDefiningOp()))
294       loopRegion = &doLoop->getRegion(0);
295     ReachCollector collector(reach, loopRegion);
296     collector.collectArrayMentionFrom(seq);
297   }
298 
299 private:
300   /// Is \p op inside the loop nest region?
301   /// FIXME: replace this structural dependence with graph properties.
302   bool opIsInsideLoops(mlir::Operation *op) const {
303     auto *region = op->getParentRegion();
304     while (region) {
305       if (region == loopRegion)
306         return true;
307       region = region->getParentRegion();
308     }
309     return false;
310   }
311 
312   /// Recursively trace the uses of an operation's results, calling
313   /// collectArrayMentionFrom on the direct and indirect user operands.
314   void followUsers(mlir::Operation *op) {
315     for (auto userOperand : op->getOperands())
316       collectArrayMentionFrom(userOperand);
317     // Go through potential converts/coordinate_op.
318     for (auto indirectUser : op->getUsers())
319       followUsers(indirectUser);
320   }
321 
322   llvm::SmallVectorImpl<mlir::Operation *> &reach;
323   llvm::SmallPtrSet<mlir::Value, 16> visited;
324   /// Region of the loop nest that produced the array value.
325   mlir::Region *loopRegion;
326 };
327 } // namespace
328 
329 /// Find all the array operations that access the array value that is loaded by
330 /// the array load operation, `load`.
331 void ArrayCopyAnalysis::arrayMentions(
332     llvm::SmallVectorImpl<mlir::Operation *> &mentions, ArrayLoadOp load) {
333   mentions.clear();
334   auto lmIter = loadMapSets.find(load);
335   if (lmIter != loadMapSets.end()) {
336     for (auto *opnd : lmIter->second) {
337       auto *owner = opnd->getOwner();
338       if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
339                     ArrayModifyOp>(owner))
340         mentions.push_back(owner);
341     }
342     return;
343   }
344 
345   UseSetT visited;
346   llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig]
347 
348   auto appendToQueue = [&](mlir::Value val) {
349     for (auto &use : val.getUses())
350       if (!visited.count(&use)) {
351         visited.insert(&use);
352         queue.push_back(&use);
353       }
354   };
355 
356   // Build the set of uses of `load`.
357   // let USES = { uses of the original array_load }
358   appendToQueue(load);
359 
360   // Process the worklist until done.
361   while (!queue.empty()) {
362     mlir::OpOperand *operand = queue.pop_back_val();
363     mlir::Operation *owner = operand->getOwner();
364     if (!owner)
365       continue;
366     auto structuredLoop = [&](auto ro) {
367       if (auto blockArg = ro.iterArgToBlockArg(operand->get())) {
368         int64_t arg = blockArg.getArgNumber();
369         mlir::Value output = ro.getResult(ro.getFinalValue() ? arg : arg - 1);
370         appendToQueue(output);
371         appendToQueue(blockArg);
372       }
373     };
374     // TODO: this needs to be updated to use the control-flow interface.
375     auto branchOp = [&](mlir::Block *dest, OperandRange operands) {
376       if (operands.empty())
377         return;
378 
379       // Check if this operand is within the range.
380       unsigned operandIndex = operand->getOperandNumber();
381       unsigned operandsStart = operands.getBeginOperandIndex();
382       if (operandIndex < operandsStart ||
383           operandIndex >= (operandsStart + operands.size()))
384         return;
385 
386       // Index the successor.
387       unsigned argIndex = operandIndex - operandsStart;
388       appendToQueue(dest->getArgument(argIndex));
389     };
390     // Thread uses into structured loop bodies and return value uses.
391     if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) {
392       structuredLoop(ro);
393     } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) {
394       structuredLoop(ro);
395     } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) {
396       // Thread any uses of fir.if that return the marked array value.
397       mlir::Operation *parent = rs->getParentRegion()->getParentOp();
398       if (auto ifOp = mlir::dyn_cast<fir::IfOp>(parent))
399         appendToQueue(ifOp.getResult(operand->getOperandNumber()));
400     } else if (mlir::isa<ArrayFetchOp>(owner)) {
401       // Keep track of array value fetches.
402       LLVM_DEBUG(llvm::dbgs()
403                  << "add fetch {" << *owner << "} to array value set\n");
404       mentions.push_back(owner);
405     } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) {
406       // Keep track of array value updates and thread the return value uses.
407       LLVM_DEBUG(llvm::dbgs()
408                  << "add update {" << *owner << "} to array value set\n");
409       mentions.push_back(owner);
410       appendToQueue(update.getResult());
411     } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) {
412       // Keep track of array value modification and thread the return value
413       // uses.
414       LLVM_DEBUG(llvm::dbgs()
415                  << "add modify {" << *owner << "} to array value set\n");
416       mentions.push_back(owner);
417       appendToQueue(update.getResult(1));
418     } else if (auto mention = mlir::dyn_cast<ArrayAccessOp>(owner)) {
419       mentions.push_back(owner);
420     } else if (auto amend = mlir::dyn_cast<ArrayAmendOp>(owner)) {
421       mentions.push_back(owner);
422       appendToQueue(amend.getResult());
423     } else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) {
424       branchOp(br.getDest(), br.getDestOperands());
425     } else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) {
426       branchOp(br.getTrueDest(), br.getTrueOperands());
427       branchOp(br.getFalseDest(), br.getFalseOperands());
428     } else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
429       // do nothing
430     } else {
431       llvm::report_fatal_error("array value reached unexpected op");
432     }
433   }
434   loadMapSets.insert({load, visited});
435 }
436 
437 static bool hasPointerType(mlir::Type type) {
438   if (auto boxTy = type.dyn_cast<BoxType>())
439     type = boxTy.getEleTy();
440   return type.isa<fir::PointerType>();
441 }
442 
443 // This is an NF performance hack. It makes a simple test that the slices of the
444 // load, \p ld, and the merge store, \p st, are trivially mutually exclusive.
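//
// For instance, slices whose triples look like (%lb, %j, %c1) for the load and
// (%j1, %ub, %c1) for the store, where %j1 is an arith.addi of %j and a
// positive constant (roughly the Fortran pattern `a(j+1:n) = a(1:j)`), are
// recognized as disjoint. Bound pairs not matching this simple shape are
// treated conservatively as potentially overlapping.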
445 static bool mutuallyExclusiveSliceRange(ArrayLoadOp ld, ArrayMergeStoreOp st) {
446   // If the same array_load, then no further testing is warranted.
447   if (ld.getResult() == st.getOriginal())
448     return false;
449 
450   auto getSliceOp = [](mlir::Value val) -> SliceOp {
451     if (!val)
452       return {};
453     auto sliceOp = mlir::dyn_cast_or_null<SliceOp>(val.getDefiningOp());
454     if (!sliceOp)
455       return {};
456     return sliceOp;
457   };
458 
459   auto ldSlice = getSliceOp(ld.getSlice());
460   auto stSlice = getSliceOp(st.getSlice());
461   if (!ldSlice || !stSlice)
462     return false;
463 
464   // Give up on subobject slices.
465   if (!ldSlice.getFields().empty() || !stSlice.getFields().empty() ||
466       !ldSlice.getSubstr().empty() || !stSlice.getSubstr().empty())
467     return false;
468 
469   // Crudely test that the two slices do not overlap by looking for the
470   // following general condition. If the slices look like (i:j) and (j+1:k) then
471   // these ranges do not overlap. The addend must be a constant.
472   auto ldTriples = ldSlice.getTriples();
473   auto stTriples = stSlice.getTriples();
474   const auto size = ldTriples.size();
475   if (size != stTriples.size())
476     return false;
477 
478   auto displacedByConstant = [](mlir::Value v1, mlir::Value v2) {
479     auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
480       auto *op = v.getDefiningOp();
481       while (auto conv = mlir::dyn_cast_or_null<ConvertOp>(op))
482         op = conv.getValue().getDefiningOp();
483       return op;
484     };
485 
486     auto isPositiveConstant = [](mlir::Value v) -> bool {
487       if (auto conOp =
488               mlir::dyn_cast<mlir::arith::ConstantOp>(v.getDefiningOp()))
489         if (auto iattr = conOp.getValue().dyn_cast<mlir::IntegerAttr>())
490           return iattr.getInt() > 0;
491       return false;
492     };
493 
494     auto *op1 = removeConvert(v1);
495     auto *op2 = removeConvert(v2);
496     if (!op1 || !op2)
497       return false;
498     if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
499       if ((addi.getLhs().getDefiningOp() == op1 &&
500            isPositiveConstant(addi.getRhs())) ||
501           (addi.getRhs().getDefiningOp() == op1 &&
502            isPositiveConstant(addi.getLhs())))
503         return true;
504     if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
505       if (subi.getLhs().getDefiningOp() == op2 &&
506           isPositiveConstant(subi.getRhs()))
507         return true;
508     return false;
509   };
510 
511   for (std::remove_const_t<decltype(size)> i = 0; i < size; i += 3) {
512     // If both are loop invariant, skip to the next triple.
513     if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i + 1].getDefiningOp()) &&
514         mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i + 1].getDefiningOp())) {
515       // Unless either is a vector index, then be conservative.
516       if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i].getDefiningOp()) ||
517           mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i].getDefiningOp()))
518         return false;
519       continue;
520     }
521     // If identical, skip to the next triple.
522     if (ldTriples[i] == stTriples[i] && ldTriples[i + 1] == stTriples[i + 1] &&
523         ldTriples[i + 2] == stTriples[i + 2])
524       continue;
525     // If ubound and lbound are the same with a constant offset, skip to the
526     // next triple.
527     if (displacedByConstant(ldTriples[i + 1], stTriples[i]) ||
528         displacedByConstant(stTriples[i + 1], ldTriples[i]))
529       continue;
530     return false;
531   }
532   LLVM_DEBUG(llvm::dbgs() << "detected non-overlapping slice ranges on " << ld
533                           << " and " << st << ", which is not a conflict\n");
534   return true;
535 }
536 
537 /// Is there a conflict between the array value that was updated and is to be
538 /// stored to `st`, and the set of arrays loaded (`reach`) and used to compute
539 /// the updated value?
540 static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach,
541                            ArrayMergeStoreOp st) {
542   mlir::Value load;
543   mlir::Value addr = st.getMemref();
544   const bool storeHasPointerType = hasPointerType(addr.getType());
545   for (auto *op : reach)
546     if (auto ld = mlir::dyn_cast<ArrayLoadOp>(op)) {
547       mlir::Type ldTy = ld.getMemref().getType();
548       if (ld.getMemref() == addr) {
549         if (mutuallyExclusiveSliceRange(ld, st))
550           continue;
551         if (ld.getResult() != st.getOriginal())
552           return true;
553         if (load) {
554           // TODO: extend this to allow checking if the first `load` and this
555           // `ld` are mutually exclusive accesses but not identical.
556           return true;
557         }
558         load = ld;
559       } else if ((hasPointerType(ldTy) || storeHasPointerType)) {
560         // TODO: Use target attribute to restrict this case further.
561         // TODO: Check if types can also allow ruling out some cases. For now,
562         // the fact that equivalences use the pointer attribute to enforce
563         // aliasing prevents any attempt to do so. In general, it may also be
564         // wrong to rely on the types if any of them is a complex or a derived
565         // type for which it is possible to create a pointer to a part with a
566         // different type than the whole, although this deserves some more
567         // investigation because existing compiler behaviors seem to diverge
568         // here.
569         return true;
570       }
571     }
572   return false;
573 }
574 
575 /// Is there an access vector conflict on the array being merged into? If the
576 /// access vectors diverge, then assume that there are potentially overlapping
577 /// loop-carried references.
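///
/// For example, if the loaded array value is fetched and updated with
/// different index vectors, schematically:
/// ```mlir
///   %v = fir.array_fetch %a, %i1 : ...
///   %a_1 = fir.array_update %a, %v, %i2 : ...
/// ```
/// then the access vectors (%i1) and (%i2) differ, a possible loop-carried
/// dependence, so a conflict is reported and a temporary copy is used.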
578 static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> mentions) {
579   if (mentions.size() < 2)
580     return false;
581   llvm::SmallVector<mlir::Value> indices;
582   LLVM_DEBUG(llvm::dbgs() << "check merge conflict with " << mentions.size()
583                           << " mentions on the list\n");
584   bool valSeen = false;
585   bool refSeen = false;
586   for (auto *op : mentions) {
587     llvm::SmallVector<mlir::Value> compareVector;
588     if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) {
589       valSeen = true;
590       if (indices.empty()) {
591         indices = u.getIndices();
592         continue;
593       }
594       compareVector = u.getIndices();
595     } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) {
596       valSeen = true;
597       if (indices.empty()) {
598         indices = f.getIndices();
599         continue;
600       }
601       compareVector = f.getIndices();
602     } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) {
603       valSeen = true;
604       if (indices.empty()) {
605         indices = f.getIndices();
606         continue;
607       }
608       compareVector = f.getIndices();
609     } else if (auto f = mlir::dyn_cast<ArrayAccessOp>(op)) {
610       refSeen = true;
611       if (indices.empty()) {
612         indices = f.getIndices();
613         continue;
614       }
615       compareVector = f.getIndices();
616     } else if (mlir::isa<ArrayAmendOp>(op)) {
617       refSeen = true;
618       continue;
619     } else {
620       mlir::emitError(op->getLoc(), "unexpected operation in analysis");
621     }
622     if (compareVector.size() != indices.size() ||
623         llvm::any_of(llvm::zip(compareVector, indices), [&](auto pair) {
624           return std::get<0>(pair) != std::get<1>(pair);
625         }))
626       return true;
627     LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n");
628   }
629   return valSeen && refSeen;
630 }
631 
632 /// With element-by-reference semantics, an amended array with more than one
633 /// access to the same loaded array is conservatively considered a conflict.
634 /// Note: the array copy can still be eliminated in subsequent optimizations.
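///
/// Schematically, a conflict is reported for a mention list such as:
/// ```mlir
///   %addr1 = fir.array_access %a, %i : ...
///   %addr2 = fir.array_access %a, %j : ...
///   %a_1 = fir.array_amend %a, %addr1 : ...
/// ```
/// because the loaded array value %a has more than one by-reference access.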
635 static bool conflictOnReference(llvm::ArrayRef<mlir::Operation *> mentions) {
636   LLVM_DEBUG(llvm::dbgs() << "checking reference semantics " << mentions.size()
637                           << '\n');
638   if (mentions.size() < 3)
639     return false;
640   unsigned amendCount = 0;
641   unsigned accessCount = 0;
642   for (auto *op : mentions) {
643     if (mlir::isa<ArrayAmendOp>(op) && ++amendCount > 1) {
644       LLVM_DEBUG(llvm::dbgs() << "conflict: multiple amends of array value\n");
645       return true;
646     }
647     if (mlir::isa<ArrayAccessOp>(op) && ++accessCount > 1) {
648       LLVM_DEBUG(llvm::dbgs()
649                  << "conflict: multiple accesses of array value\n");
650       return true;
651     }
652     if (mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) {
653       LLVM_DEBUG(llvm::dbgs()
654                  << "conflict: array value has both uses by-value and uses "
655                     "by-reference. conservative assumption.\n");
656       return true;
657     }
658   }
659   return false;
660 }
661 
662 static mlir::Operation *
663 amendingAccess(llvm::ArrayRef<mlir::Operation *> mentions) {
664   for (auto *op : mentions)
665     if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op))
666       return amend.getMemref().getDefiningOp();
667   return {};
668 }
669 
670 // Are any conflicts present? The conflicts detected here are described above.
671 static bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
672                              llvm::ArrayRef<mlir::Operation *> mentions,
673                              ArrayMergeStoreOp st) {
674   return conflictOnLoad(reach, st) || conflictOnMerge(mentions);
675 }
676 
677 // Assume that any call to a function that uses host-associations will be
678 // modifying the output array.
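//
// For example (an illustrative sketch), in
//   subroutine sub(a, b)
//     integer :: a(10), b(10)
//     a = f(b)
//   contains
//     elemental integer function f(x)
//       integer, intent(in) :: x
//       f = x + a(1)  ! references host-associated `a`
//     end function
//   end subroutine
// the calls to `f` reach the stored value and `f` carries a host-association
// argument, so the assignment to `a` conservatively gets a temporary copy.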
679 static bool
680 conservativeCallConflict(llvm::ArrayRef<mlir::Operation *> reaches) {
681   return llvm::any_of(reaches, [](mlir::Operation *op) {
682     if (auto call = mlir::dyn_cast<fir::CallOp>(op))
683       if (auto callee =
684               call.getCallableForCallee().dyn_cast<mlir::SymbolRefAttr>()) {
685         auto module = op->getParentOfType<mlir::ModuleOp>();
686         return hasHostAssociationArgument(
687             module.lookupSymbol<mlir::func::FuncOp>(callee));
688       }
689     return false;
690   });
691 }
692 
693 /// Constructor of the array copy analysis.
694 /// This performs the analysis and saves the intermediate results.
695 void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) {
696   topLevelOp->walk([&](Operation *op) {
697     if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) {
698       llvm::SmallVector<mlir::Operation *> values;
699       ReachCollector::reachingValues(values, st.getSequence());
700       bool callConflict = conservativeCallConflict(values);
701       llvm::SmallVector<mlir::Operation *> mentions;
702       arrayMentions(mentions,
703                     mlir::cast<ArrayLoadOp>(st.getOriginal().getDefiningOp()));
704       bool conflict = conflictDetected(values, mentions, st);
705       bool refConflict = conflictOnReference(mentions);
706       if (callConflict || conflict || refConflict) {
707         LLVM_DEBUG(llvm::dbgs()
708                    << "CONFLICT: copies required for " << st << '\n'
709                    << "   adding conflicts on: " << op << " and "
710                    << st.getOriginal() << '\n');
711         conflicts.insert(op);
712         conflicts.insert(st.getOriginal().getDefiningOp());
713         if (auto *access = amendingAccess(mentions))
714           amendAccesses.insert(access);
715       }
716       auto *ld = st.getOriginal().getDefiningOp();
717       LLVM_DEBUG(llvm::dbgs()
718                  << "map: adding {" << *ld << " -> " << st << "}\n");
719       useMap.insert({ld, op});
720     } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) {
721       llvm::SmallVector<mlir::Operation *> mentions;
722       arrayMentions(mentions, load);
723       LLVM_DEBUG(llvm::dbgs() << "process load: " << load
724                               << ", mentions: " << mentions.size() << '\n');
725       for (auto *acc : mentions) {
726         LLVM_DEBUG(llvm::dbgs() << " mention: " << *acc << '\n');
727         if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
728                       ArrayModifyOp>(acc)) {
729           if (useMap.count(acc)) {
730             mlir::emitError(
731                 load.getLoc(),
732                 "The parallel semantics of multiple array_merge_stores per "
733                 "array_load are not supported.");
734             continue;
735           }
736           LLVM_DEBUG(llvm::dbgs()
737                      << "map: adding {" << *acc << "} -> {" << load << "}\n");
738           useMap.insert({acc, op});
739         }
740       }
741     }
742   });
743 }
744 
745 //===----------------------------------------------------------------------===//
746 // Conversions for converting out of array value form.
747 //===----------------------------------------------------------------------===//
748 
749 namespace {
750 class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> {
751 public:
752   using OpRewritePattern::OpRewritePattern;
753 
754   mlir::LogicalResult
755   matchAndRewrite(ArrayLoadOp load,
756                   mlir::PatternRewriter &rewriter) const override {
757     LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n");
758     rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType());
759     return mlir::success();
760   }
761 };
762 
763 class ArrayMergeStoreConversion
764     : public mlir::OpRewritePattern<ArrayMergeStoreOp> {
765 public:
766   using OpRewritePattern::OpRewritePattern;
767 
768   mlir::LogicalResult
769   matchAndRewrite(ArrayMergeStoreOp store,
770                   mlir::PatternRewriter &rewriter) const override {
771     LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n");
772     rewriter.eraseOp(store);
773     return mlir::success();
774   }
775 };
776 } // namespace
777 
778 static mlir::Type getEleTy(mlir::Type ty) {
779   auto eleTy = unwrapSequenceType(unwrapPassByRefType(ty));
780   // FIXME: keep ptr/heap/ref information.
781   return ReferenceType::get(eleTy);
782 }
783 
784 // Extract extents from the ShapeOp/ShapeShiftOp into the result vector.
785 static bool getAdjustedExtents(mlir::Location loc,
786                                mlir::PatternRewriter &rewriter,
787                                ArrayLoadOp arrLoad,
788                                llvm::SmallVectorImpl<mlir::Value> &result,
789                                mlir::Value shape) {
790   bool copyUsingSlice = false;
791   auto *shapeOp = shape.getDefiningOp();
792   if (auto s = mlir::dyn_cast_or_null<ShapeOp>(shapeOp)) {
793     auto e = s.getExtents();
794     result.insert(result.end(), e.begin(), e.end());
795   } else if (auto s = mlir::dyn_cast_or_null<ShapeShiftOp>(shapeOp)) {
796     auto e = s.getExtents();
797     result.insert(result.end(), e.begin(), e.end());
798   } else {
799     emitFatalError(loc, "not a fir.shape/fir.shape_shift op");
800   }
801   auto idxTy = rewriter.getIndexType();
802   if (factory::isAssumedSize(result)) {
803     // Use slice information to compute the extent of the column.
804     auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
805     mlir::Value size = one;
806     if (mlir::Value sliceArg = arrLoad.getSlice()) {
807       if (auto sliceOp =
808               mlir::dyn_cast_or_null<SliceOp>(sliceArg.getDefiningOp())) {
809         auto triples = sliceOp.getTriples();
810         const std::size_t tripleSize = triples.size();
811         auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
812         FirOpBuilder builder(rewriter, getKindMapping(module));
813         size = builder.genExtentFromTriplet(loc, triples[tripleSize - 3],
814                                             triples[tripleSize - 2],
815                                             triples[tripleSize - 1], idxTy);
816         copyUsingSlice = true;
817       }
818     }
819     result[result.size() - 1] = size;
820   }
821   return copyUsingSlice;
822 }
823 
824 /// Place the extents of the array load, \p arrLoad, into \p result and
825 /// return a ShapeOp or ShapeShiftOp with the same extents. If \p arrLoad is
826 /// loading a `!fir.box`, code will be generated to read the extents from the
827 /// boxed value, and the returned shape Op will be built with the extents read
828 /// from the box. Otherwise, the extents will be extracted from the ShapeOp (or
829 /// ShapeShiftOp) argument of \p arrLoad. \p copyUsingSlice will be set to true
830 /// if slicing of the output array is to be done in the copy-in/copy-out rather
831 /// than in the elemental computation step.
832 static mlir::Value getOrReadExtentsAndShapeOp(
833     mlir::Location loc, mlir::PatternRewriter &rewriter, ArrayLoadOp arrLoad,
834     llvm::SmallVectorImpl<mlir::Value> &result, bool &copyUsingSlice) {
835   assert(result.empty());
836   if (arrLoad->hasAttr(fir::getOptionalAttrName()))
837     fir::emitFatalError(
838         loc, "shapes from array load of OPTIONAL arrays must not be used");
839   if (auto boxTy = arrLoad.getMemref().getType().dyn_cast<BoxType>()) {
840     auto rank =
841         dyn_cast_ptrOrBoxEleTy(boxTy).cast<SequenceType>().getDimension();
842     auto idxTy = rewriter.getIndexType();
843     for (decltype(rank) dim = 0; dim < rank; ++dim) {
844       auto dimVal = rewriter.create<mlir::arith::ConstantIndexOp>(loc, dim);
845       auto dimInfo = rewriter.create<BoxDimsOp>(loc, idxTy, idxTy, idxTy,
846                                                 arrLoad.getMemref(), dimVal);
847       result.emplace_back(dimInfo.getResult(1));
848     }
849     if (!arrLoad.getShape()) {
850       auto shapeType = ShapeType::get(rewriter.getContext(), rank);
851       return rewriter.create<ShapeOp>(loc, shapeType, result);
852     }
853     auto shiftOp = arrLoad.getShape().getDefiningOp<ShiftOp>();
854     auto shapeShiftType = ShapeShiftType::get(rewriter.getContext(), rank);
855     llvm::SmallVector<mlir::Value> shapeShiftOperands;
856     for (auto [lb, extent] : llvm::zip(shiftOp.getOrigins(), result)) {
857       shapeShiftOperands.push_back(lb);
858       shapeShiftOperands.push_back(extent);
859     }
860     return rewriter.create<ShapeShiftOp>(loc, shapeShiftType,
861                                          shapeShiftOperands);
862   }
863   copyUsingSlice =
864       getAdjustedExtents(loc, rewriter, arrLoad, result, arrLoad.getShape());
865   return arrLoad.getShape();
866 }
867 
868 static mlir::Type toRefType(mlir::Type ty) {
869   if (fir::isa_ref_type(ty))
870     return ty;
871   return fir::ReferenceType::get(ty);
872 }
873 
874 static llvm::SmallVector<mlir::Value>
875 getTypeParamsIfRawData(mlir::Location loc, FirOpBuilder &builder,
876                        ArrayLoadOp arrLoad, mlir::Type ty) {
877   if (ty.isa<BoxType>())
878     return {};
879   return fir::factory::getTypeParams(loc, builder, arrLoad);
880 }
881 
882 static mlir::Value genCoorOp(mlir::PatternRewriter &rewriter,
883                              mlir::Location loc, mlir::Type eleTy,
884                              mlir::Type resTy, mlir::Value alloc,
885                              mlir::Value shape, mlir::Value slice,
886                              mlir::ValueRange indices, ArrayLoadOp load,
887                              bool skipOrig = false) {
888   llvm::SmallVector<mlir::Value> originated;
889   if (skipOrig)
890     originated.assign(indices.begin(), indices.end());
891   else
892     originated = factory::originateIndices(loc, rewriter, alloc.getType(),
893                                            shape, indices);
894   auto seqTy = dyn_cast_ptrOrBoxEleTy(alloc.getType());
895   assert(seqTy && seqTy.isa<SequenceType>());
896   const auto dimension = seqTy.cast<SequenceType>().getDimension();
897   auto module = load->getParentOfType<mlir::ModuleOp>();
898   FirOpBuilder builder(rewriter, getKindMapping(module));
899   auto typeparams = getTypeParamsIfRawData(loc, builder, load, alloc.getType());
900   mlir::Value result = rewriter.create<ArrayCoorOp>(
901       loc, eleTy, alloc, shape, slice,
902       llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension),
903       typeparams);
904   if (dimension < originated.size())
905     result = rewriter.create<fir::CoordinateOp>(
906         loc, resTy, result,
907         llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension));
908   return result;
909 }
910 
911 static mlir::Value getCharacterLen(mlir::Location loc, FirOpBuilder &builder,
912                                    ArrayLoadOp load, CharacterType charTy) {
913   auto charLenTy = builder.getCharacterLengthType();
914   if (charTy.hasDynamicLen()) {
915     if (load.getMemref().getType().isa<BoxType>()) {
916       // The loaded array is an emboxed value. Get the CHARACTER length from
917       // the box value.
918       auto eleSzInBytes =
919           builder.create<BoxEleSizeOp>(loc, charLenTy, load.getMemref());
920       auto kindSize =
921           builder.getKindMap().getCharacterBitsize(charTy.getFKind());
922       auto kindByteSize =
923           builder.createIntegerConstant(loc, charLenTy, kindSize / 8);
924       return builder.create<mlir::arith::DivSIOp>(loc, eleSzInBytes,
925                                                   kindByteSize);
926     }
927     // The loaded array is a (set of) unboxed values. If the CHARACTER's
928     // length is not a constant, it must be provided as a type parameter to
929     // the array_load.
930     auto typeparams = load.getTypeparams();
931     assert(typeparams.size() > 0 && "expected type parameters on array_load");
932     return typeparams.back();
933   }
934   // The typical case: the length of the CHARACTER is a compile-time
935   // constant that is encoded in the type information.
936   return builder.createIntegerConstant(loc, charLenTy, charTy.getLen());
937 }
938 /// Generate a shallow array copy. This is used for both copy-in and copy-out.
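/// Schematically, for a rank-2 array of a trivial element type, this emits a
/// loop nest of the following form (a sketch: the indices are first adjusted
/// to the array origin, and CHARACTER elements are copied with a character
/// assignment instead of a plain load/store):
/// ```mlir
///   fir.do_loop %j = %c0 to %ub1 step %c1 {
///     fir.do_loop %i = %c0 to %ub0 step %c1 {
///       %from = fir.array_coor %src(%shape) %i, %j : ...
///       %to = fir.array_coor %dst(%shape) %i, %j : ...
///       %val = fir.load %from : ...
///       fir.store %val to %to : ...
///     }
///   }
/// ```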
939 template <bool CopyIn>
940 void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter,
941                   mlir::Value dst, mlir::Value src, mlir::Value shapeOp,
942                   mlir::Value sliceOp, ArrayLoadOp arrLoad) {
943   auto insPt = rewriter.saveInsertionPoint();
944   llvm::SmallVector<mlir::Value> indices;
945   llvm::SmallVector<mlir::Value> extents;
946   bool copyUsingSlice =
947       getAdjustedExtents(loc, rewriter, arrLoad, extents, shapeOp);
948   auto idxTy = rewriter.getIndexType();
949   // Build loop nest from column to row.
950   for (auto sh : llvm::reverse(extents)) {
951     auto ubi = rewriter.create<ConvertOp>(loc, idxTy, sh);
952     auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
953     auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
954     auto ub = rewriter.create<mlir::arith::SubIOp>(loc, idxTy, ubi, one);
955     auto loop = rewriter.create<DoLoopOp>(loc, zero, ub, one);
956     rewriter.setInsertionPointToStart(loop.getBody());
957     indices.push_back(loop.getInductionVar());
958   }
959   // Reverse the indices so they are in column-major order.
960   std::reverse(indices.begin(), indices.end());
961   auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
962   FirOpBuilder builder(rewriter, getKindMapping(module));
963   auto fromAddr = rewriter.create<ArrayCoorOp>(
964       loc, getEleTy(src.getType()), src, shapeOp,
965       CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
966       factory::originateIndices(loc, rewriter, src.getType(), shapeOp, indices),
967       getTypeParamsIfRawData(loc, builder, arrLoad, src.getType()));
968   auto toAddr = rewriter.create<ArrayCoorOp>(
969       loc, getEleTy(dst.getType()), dst, shapeOp,
970       !CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
971       factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, indices),
972       getTypeParamsIfRawData(loc, builder, arrLoad, src.getType()));
973   auto eleTy = unwrapSequenceType(unwrapPassByRefType(dst.getType()));
974   // Copy from (to) object to (from) temp copy of same object.
975   if (auto charTy = eleTy.dyn_cast<CharacterType>()) {
976     auto len = getCharacterLen(loc, builder, arrLoad, charTy);
977     CharBoxValue toChar(toAddr, len);
978     CharBoxValue fromChar(fromAddr, len);
979     factory::genScalarAssignment(builder, loc, toChar, fromChar);
980   } else {
981     if (hasDynamicSize(eleTy))
982       TODO(loc, "copy element of dynamic size");
983     factory::genScalarAssignment(builder, loc, toAddr, fromAddr);
984   }
985   rewriter.restoreInsertionPoint(insPt);
986 }
987 
988 /// The array load may be either a boxed or unboxed value. If the value is
989 /// boxed, we read the type parameters from the boxed value.
990 static llvm::SmallVector<mlir::Value>
991 genArrayLoadTypeParameters(mlir::Location loc, mlir::PatternRewriter &rewriter,
992                            ArrayLoadOp load) {
993   if (load.getTypeparams().empty()) {
994     auto eleTy =
995         unwrapSequenceType(unwrapPassByRefType(load.getMemref().getType()));
996     if (hasDynamicSize(eleTy)) {
997       if (auto charTy = eleTy.dyn_cast<CharacterType>()) {
998         assert(load.getMemref().getType().isa<BoxType>());
999         auto module = load->getParentOfType<mlir::ModuleOp>();
1000         FirOpBuilder builder(rewriter, getKindMapping(module));
1001         return {getCharacterLen(loc, builder, load, charTy)};
1002       }
1003       TODO(loc, "unhandled dynamic type parameters");
1004     }
1005     return {};
1006   }
1007   return load.getTypeparams();
1008 }
1009 
1010 static llvm::SmallVector<mlir::Value>
1011 findNonconstantExtents(mlir::Type memrefTy,
1012                        llvm::ArrayRef<mlir::Value> extents) {
1013   llvm::SmallVector<mlir::Value> nce;
1014   auto arrTy = unwrapPassByRefType(memrefTy);
1015   auto seqTy = arrTy.cast<SequenceType>();
1016   for (auto [s, x] : llvm::zip(seqTy.getShape(), extents))
1017     if (s == SequenceType::getUnknownExtent())
1018       nce.emplace_back(x);
1019   if (extents.size() > seqTy.getShape().size())
1020     for (auto x : extents.drop_front(seqTy.getShape().size()))
1021       nce.emplace_back(x);
1022   return nce;
1023 }
1024 
1025 /// Allocate temporary storage for an ArrayLoadOp \p load and initialize any
1026 /// allocatable direct components of the array elements with an unallocated
1027 /// status. Returns the temporary address as well as a callback to generate the
1028 /// temporary clean-up once it has been used. The clean-up will take care of
1029 /// deallocating all the element allocatable components that may have been
1030 /// allocated while using the temporary.
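///
/// A typical use is, schematically (names are illustrative):
/// ```cpp
///   auto [tmp, cleanup] = allocateArrayTemp(loc, rewriter, load, extents, shape);
///   // ... copy in, rewrite the element accesses to use `tmp`, copy out ...
///   cleanup(rewriter);
/// ```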
1031 static std::pair<mlir::Value,
1032                  std::function<void(mlir::PatternRewriter &rewriter)>>
1033 allocateArrayTemp(mlir::Location loc, mlir::PatternRewriter &rewriter,
1034                   ArrayLoadOp load, llvm::ArrayRef<mlir::Value> extents,
1035                   mlir::Value shape) {
1036   mlir::Type baseType = load.getMemref().getType();
1037   llvm::SmallVector<mlir::Value> nonconstantExtents =
1038       findNonconstantExtents(baseType, extents);
1039   llvm::SmallVector<mlir::Value> typeParams =
1040       genArrayLoadTypeParameters(loc, rewriter, load);
1041   mlir::Value allocmem = rewriter.create<AllocMemOp>(
1042       loc, dyn_cast_ptrOrBoxEleTy(baseType), typeParams, nonconstantExtents);
1043   mlir::Type eleType =
1044       fir::unwrapSequenceType(fir::unwrapPassByRefType(baseType));
1045   if (fir::isRecordWithAllocatableMember(eleType)) {
1046     // The allocatable component descriptors need to be set to a clean
1047     // deallocated status before anything is done with them.
1048     mlir::Value box = rewriter.create<fir::EmboxOp>(
1049         loc, fir::BoxType::get(baseType), allocmem, shape,
1050         /*slice=*/mlir::Value{}, typeParams);
1051     auto module = load->getParentOfType<mlir::ModuleOp>();
1052     FirOpBuilder builder(rewriter, getKindMapping(module));
1053     runtime::genDerivedTypeInitialize(builder, loc, box);
1054     // Any allocatable component that may have been allocated must be
1055     // deallocated during the clean-up.
1056     auto cleanup = [=](mlir::PatternRewriter &r) {
1057       FirOpBuilder builder(r, getKindMapping(module));
1058       runtime::genDerivedTypeDestroy(builder, loc, box);
1059       r.create<FreeMemOp>(loc, allocmem);
1060     };
1061     return {allocmem, cleanup};
1062   }
1063   auto cleanup = [=](mlir::PatternRewriter &r) {
1064     r.create<FreeMemOp>(loc, allocmem);
1065   };
1066   return {allocmem, cleanup};
1067 }
1068 
1069 namespace {
1070 /// Conversion of fir.array_update and fir.array_modify Ops.
1071 /// If there is a conflict for the update, then we need to perform a
1072 /// copy-in/copy-out to preserve the original values of the array. If there is
1073 /// no conflict, then it is safe to eschew making any copies.
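///
/// For example, an array reversal such as `a = a(10:1:-1)` has a potential
/// overlap between the LHS and RHS, so the rewrite allocates a temporary,
/// copies `a` in, applies the element updates to the temporary, and copies the
/// temporary back out before the array_merge_store. An assignment between
/// distinct, non-pointer arrays such as `b = a + 1` is updated in place.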
1074 template <typename ArrayOp>
1075 class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> {
1076 public:
1077   // TODO: Implement copy/swap semantics?
1078   explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx,
1079                                      const ArrayCopyAnalysis &a,
1080                                      const OperationUseMapT &m)
1081       : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {}
1082 
1083   /// The array_access, \p access, is to be made to a cloned copy due to a
1084   /// potential conflict. Uses copy-in/copy-out semantics and not copy/swap.
1085   mlir::Value referenceToClone(mlir::Location loc,
1086                                mlir::PatternRewriter &rewriter,
1087                                ArrayOp access) const {
1088     LLVM_DEBUG(llvm::dbgs()
1089                << "generating copy-in/copy-out loops for " << access << '\n');
1090     auto *op = access.getOperation();
1091     auto *loadOp = useMap.lookup(op);
1092     auto load = mlir::cast<ArrayLoadOp>(loadOp);
1093     auto eleTy = access.getType();
1094     rewriter.setInsertionPoint(loadOp);
1095     // Copy in.
1096     llvm::SmallVector<mlir::Value> extents;
1097     bool copyUsingSlice = false;
1098     auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
1099                                               copyUsingSlice);
1100     auto [allocmem, genTempCleanUp] =
1101         allocateArrayTemp(loc, rewriter, load, extents, shapeOp);
1102     genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
1103                                   load.getMemref(), shapeOp, load.getSlice(),
1104                                   load);
1105     // Generate the reference for the access.
1106     rewriter.setInsertionPoint(op);
1107     auto coor = genCoorOp(
1108         rewriter, loc, getEleTy(load.getType()), eleTy, allocmem, shapeOp,
1109         copyUsingSlice ? mlir::Value{} : load.getSlice(), access.getIndices(),
1110         load, access->hasAttr(factory::attrFortranArrayOffsets()));
1111     // Copy out.
1112     auto *storeOp = useMap.lookup(loadOp);
1113     auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
1114     rewriter.setInsertionPoint(storeOp);
1115     // Copy out.
1116     genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter, store.getMemref(),
1117                                    allocmem, shapeOp, store.getSlice(), load);
1118     genTempCleanUp(rewriter);
1119     return coor;
1120   }
1121 
1122   /// Copy the RHS element into the LHS and insert copy-in/copy-out between a
1123   /// temp and the LHS if the analysis found potential overlaps between the RHS
1124   /// and LHS arrays. The element copy generator must be provided in \p
1125   /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp.
1126   /// Returns the address of the LHS element inside the loop and the LHS
1127   /// ArrayLoad result.
1128   std::pair<mlir::Value, mlir::Value>
1129   materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter,
1130                         ArrayOp update,
1131                         const std::function<void(mlir::Value)> &assignElement,
1132                         mlir::Type lhsEltRefType) const {
1133     auto *op = update.getOperation();
1134     auto *loadOp = useMap.lookup(op);
1135     auto load = mlir::cast<ArrayLoadOp>(loadOp);
1136     LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n");
1137     if (analysis.hasPotentialConflict(loadOp)) {
1138       // If there is a conflict between the arrays, then we copy the lhs array
1139       // to a temporary, update the temporary, and copy the temporary back to
1140       // the lhs array. This yields Fortran's copy-in copy-out array semantics.
1141       LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n");
1142       rewriter.setInsertionPoint(loadOp);
1143       // Copy in.
1144       llvm::SmallVector<mlir::Value> extents;
1145       bool copyUsingSlice = false;
1146       auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
1147                                                 copyUsingSlice);
1148       auto [allocmem, genTempCleanUp] =
1149           allocateArrayTemp(loc, rewriter, load, extents, shapeOp);
1150 
1151       genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
1152                                     load.getMemref(), shapeOp, load.getSlice(),
1153                                     load);
1154       rewriter.setInsertionPoint(op);
1155       auto coor = genCoorOp(
1156           rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem,
1157           shapeOp, copyUsingSlice ? mlir::Value{} : load.getSlice(),
1158           update.getIndices(), load,
1159           update->hasAttr(factory::attrFortranArrayOffsets()));
1160       assignElement(coor);
1161       auto *storeOp = useMap.lookup(loadOp);
1162       auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
1163       rewriter.setInsertionPoint(storeOp);
1164       // Copy out.
1165       genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter,
1166                                      store.getMemref(), allocmem, shapeOp,
1167                                      store.getSlice(), load);
1168       genTempCleanUp(rewriter);
1169       return {coor, load.getResult()};
1170     }
1171     // Otherwise, when there is no conflict (such as a possible loop-carried
1172     // dependence), the lhs array can be updated in place.
1173     LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n");
1174     rewriter.setInsertionPoint(op);
1175     auto coorTy = getEleTy(load.getType());
1176     auto coor =
1177         genCoorOp(rewriter, loc, coorTy, lhsEltRefType, load.getMemref(),
1178                   load.getShape(), load.getSlice(), update.getIndices(), load,
1179                   update->hasAttr(factory::attrFortranArrayOffsets()));
1180     assignElement(coor);
1181     return {coor, load.getResult()};
1182   }
1183 
1184 protected:
1185   const ArrayCopyAnalysis &analysis;
1186   const OperationUseMapT &useMap;
1187 };
1188 
1189 class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> {
1190 public:
1191   explicit ArrayUpdateConversion(mlir::MLIRContext *ctx,
1192                                  const ArrayCopyAnalysis &a,
1193                                  const OperationUseMapT &m)
1194       : ArrayUpdateConversionBase{ctx, a, m} {}
1195 
1196   mlir::LogicalResult
1197   matchAndRewrite(ArrayUpdateOp update,
1198                   mlir::PatternRewriter &rewriter) const override {
1199     auto loc = update.getLoc();
1200     auto assignElement = [&](mlir::Value coor) {
1201       auto input = update.getMerge();
1202       if (auto inEleTy = dyn_cast_ptrEleTy(input.getType())) {
1203         emitFatalError(loc, "array_update on references not supported");
1204       } else {
1205         rewriter.create<fir::StoreOp>(loc, input, coor);
1206       }
1207     };
1208     auto lhsEltRefType = toRefType(update.getMerge().getType());
1209     auto [_, lhsLoadResult] = materializeAssignment(
1210         loc, rewriter, update, assignElement, lhsEltRefType);
1211     update.replaceAllUsesWith(lhsLoadResult);
1212     rewriter.replaceOp(update, lhsLoadResult);
1213     return mlir::success();
1214   }
1215 };
1216 
1217 class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> {
1218 public:
1219   explicit ArrayModifyConversion(mlir::MLIRContext *ctx,
1220                                  const ArrayCopyAnalysis &a,
1221                                  const OperationUseMapT &m)
1222       : ArrayUpdateConversionBase{ctx, a, m} {}
1223 
1224   mlir::LogicalResult
1225   matchAndRewrite(ArrayModifyOp modify,
1226                   mlir::PatternRewriter &rewriter) const override {
1227     auto loc = modify.getLoc();
1228     auto assignElement = [](mlir::Value) {
1229       // Assignment already materialized by lowering using lhs element address.
1230     };
1231     auto lhsEltRefType = modify.getResult(0).getType();
1232     auto [lhsEltCoor, lhsLoadResult] = materializeAssignment(
1233         loc, rewriter, modify, assignElement, lhsEltRefType);
1234     modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult});
1235     rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult});
1236     return mlir::success();
1237   }
1238 };
1239 
1240 class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> {
1241 public:
1242   explicit ArrayFetchConversion(mlir::MLIRContext *ctx,
1243                                 const OperationUseMapT &m)
1244       : OpRewritePattern{ctx}, useMap{m} {}
1245 
1246   mlir::LogicalResult
1247   matchAndRewrite(ArrayFetchOp fetch,
1248                   mlir::PatternRewriter &rewriter) const override {
1249     auto *op = fetch.getOperation();
1250     rewriter.setInsertionPoint(op);
1251     auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
1252     auto loc = fetch.getLoc();
1253     auto coor = genCoorOp(
1254         rewriter, loc, getEleTy(load.getType()), toRefType(fetch.getType()),
1255         load.getMemref(), load.getShape(), load.getSlice(), fetch.getIndices(),
1256         load, fetch->hasAttr(factory::attrFortranArrayOffsets()));
1257     if (isa_ref_type(fetch.getType()))
1258       rewriter.replaceOp(fetch, coor);
1259     else
1260       rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor);
1261     return mlir::success();
1262   }
1263 
1264 private:
1265   const OperationUseMapT &useMap;
1266 };
1267 
1268 /// An array_access op is like an array_fetch op, except that it does not imply
1269 /// a load op. (It operates in the reference domain.)
1270 class ArrayAccessConversion : public ArrayUpdateConversionBase<ArrayAccessOp> {
1271 public:
1272   explicit ArrayAccessConversion(mlir::MLIRContext *ctx,
1273                                  const ArrayCopyAnalysis &a,
1274                                  const OperationUseMapT &m)
1275       : ArrayUpdateConversionBase{ctx, a, m} {}
1276 
1277   mlir::LogicalResult
1278   matchAndRewrite(ArrayAccessOp access,
1279                   mlir::PatternRewriter &rewriter) const override {
1280     auto *op = access.getOperation();
1281     auto loc = access.getLoc();
1282     if (analysis.inAmendAccessSet(op)) {
1283       // This array_access is associated with an array_amend and there is a
1284       // conflict. Make a copy to store into.
1285       auto result = referenceToClone(loc, rewriter, access);
1286       access.replaceAllUsesWith(result);
1287       rewriter.replaceOp(access, result);
1288       return mlir::success();
1289     }
1290     rewriter.setInsertionPoint(op);
1291     auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
1292     auto coor = genCoorOp(
1293         rewriter, loc, getEleTy(load.getType()), toRefType(access.getType()),
1294         load.getMemref(), load.getShape(), load.getSlice(), access.getIndices(),
1295         load, access->hasAttr(factory::attrFortranArrayOffsets()));
1296     rewriter.replaceOp(access, coor);
1297     return mlir::success();
1298   }
1299 };
1300 
1301 /// An array_amend op is a marker to record which array access is being used to
1302 /// update an array value. After this pass runs, an array_amend has no
1303 /// semantics. We rewrite these to undefined values here to remove them while
1304 /// preserving SSA form.
1305 class ArrayAmendConversion : public mlir::OpRewritePattern<ArrayAmendOp> {
1306 public:
1307   explicit ArrayAmendConversion(mlir::MLIRContext *ctx)
1308       : OpRewritePattern{ctx} {}
1309 
1310   mlir::LogicalResult
1311   matchAndRewrite(ArrayAmendOp amend,
1312                   mlir::PatternRewriter &rewriter) const override {
1313     auto *op = amend.getOperation();
1314     rewriter.setInsertionPoint(op);
1315     auto loc = amend.getLoc();
1316     auto undef = rewriter.create<UndefOp>(loc, amend.getType());
1317     rewriter.replaceOp(amend, undef.getResult());
1318     return mlir::success();
1319   }
1320 };
1321 
1322 class ArrayValueCopyConverter
1323     : public ArrayValueCopyBase<ArrayValueCopyConverter> {
1324 public:
1325   void runOnOperation() override {
1326     auto func = getOperation();
1327     LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '"
1328                             << func.getName() << "'\n");
1329     auto *context = &getContext();
1330 
1331     // Perform the conflict analysis.
1332     const auto &analysis = getAnalysis<ArrayCopyAnalysis>();
1333     const auto &useMap = analysis.getUseMap();
1334 
1335     mlir::RewritePatternSet patterns1(context);
1336     patterns1.insert<ArrayFetchConversion>(context, useMap);
1337     patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
1338     patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
1339     patterns1.insert<ArrayAccessConversion>(context, analysis, useMap);
1340     patterns1.insert<ArrayAmendConversion>(context);
1341     mlir::ConversionTarget target(*context);
1342     target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
1343                            mlir::arith::ArithmeticDialect,
1344                            mlir::func::FuncDialect>();
1345     target.addIllegalOp<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp,
1346                         ArrayUpdateOp, ArrayModifyOp>();
1347     // Rewrite the array fetch and array update ops.
1348     if (mlir::failed(
1349             mlir::applyPartialConversion(func, target, std::move(patterns1)))) {
1350       mlir::emitError(mlir::UnknownLoc::get(context),
1351                       "failure in array-value-copy pass, phase 1");
1352       signalPassFailure();
1353     }
1354 
1355     mlir::RewritePatternSet patterns2(context);
1356     patterns2.insert<ArrayLoadConversion>(context);
1357     patterns2.insert<ArrayMergeStoreConversion>(context);
1358     target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>();
1359     if (mlir::failed(
1360             mlir::applyPartialConversion(func, target, std::move(patterns2)))) {
1361       mlir::emitError(mlir::UnknownLoc::get(context),
1362                       "failure in array-value-copy pass, phase 2");
1363       signalPassFailure();
1364     }
1365   }
1366 };
1367 } // namespace
1368 
1369 std::unique_ptr<mlir::Pass> fir::createArrayValueCopyPass() {
1370   return std::make_unique<ArrayValueCopyConverter>();
1371 }
1372