1 //===-- ArrayValueCopy.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/BoxValue.h"
10 #include "flang/Optimizer/Builder/FIRBuilder.h"
11 #include "flang/Optimizer/Builder/Factory.h"
12 #include "flang/Optimizer/Builder/Runtime/Derived.h"
13 #include "flang/Optimizer/Builder/Todo.h"
14 #include "flang/Optimizer/Dialect/FIRDialect.h"
15 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
16 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
17 #include "flang/Optimizer/Transforms/Passes.h"
18 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/Support/Debug.h"
22 
23 namespace fir {
24 #define GEN_PASS_DEF_ARRAYVALUECOPY
25 #include "flang/Optimizer/Transforms/Passes.h.inc"
26 } // namespace fir
27 
28 #define DEBUG_TYPE "flang-array-value-copy"
29 
30 using namespace fir;
31 using namespace mlir;
32 
33 using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>;
34 
35 namespace {
36 
37 /// Array copy analysis.
38 /// Perform an interference analysis between array values.
39 ///
40 /// Lowering will generate a sequence of the following form.
41 /// ```mlir
42 ///   %a_1 = fir.array_load %array_1(%shape) : ...
43 ///   ...
44 ///   %a_j = fir.array_load %array_j(%shape) : ...
45 ///   ...
46 ///   %a_n = fir.array_load %array_n(%shape) : ...
47 ///     ...
48 ///     %v_i = fir.array_fetch %a_i, ...
49 ///     %a_j1 = fir.array_update %a_j, ...
50 ///     ...
51 ///   fir.array_merge_store %a_j, %a_jn to %array_j : ...
52 /// ```
53 ///
54 /// The analysis is to determine if there are any conflicts. A conflict is when
55 /// one of the following cases occurs.
56 ///
57 /// 1. There is an `array_update` to an array value, a_j, such that a_j was
58 /// loaded from the same array memory reference (array_j) but with a different
59 /// shape than the other array values a_i, where i != j. [Possible overlapping
60 /// arrays.]
61 ///
62 /// 2. There is either an array_fetch or array_update of a_j with a different
63 /// set of index values. [Possible loop-carried dependence.]
64 ///
65 /// If none of the array values overlap in storage and the accesses are not
66 /// loop-carried, then the arrays are conflict-free and no copies are required.
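///
/// For example, the Fortran assignment `a(2:10) = a(1:9)` loads and updates
/// the same array storage with index sets that differ by one, so a temporary
/// copy is required to honor the semantics of fully evaluating the right-hand
/// side before any element of the left-hand side is assigned.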
67 class ArrayCopyAnalysisBase {
68 public:
69   using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>;
70   using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>;
71   using LoadMapSetsT = llvm::DenseMap<mlir::Operation *, UseSetT>;
72   using AmendAccessSetT = llvm::SmallPtrSet<mlir::Operation *, 4>;
73 
74   ArrayCopyAnalysisBase(mlir::Operation *op, bool optimized)
75       : operation{op}, optimizeConflicts(optimized) {
76     construct(op);
77   }
78   virtual ~ArrayCopyAnalysisBase() = default;
79 
80   mlir::Operation *getOperation() const { return operation; }
81 
82   /// Return true iff the `array_merge_store` has potential conflicts.
83   bool hasPotentialConflict(mlir::Operation *op) const {
84     LLVM_DEBUG(llvm::dbgs()
85                << "looking for a conflict on " << *op
86                << " and the set has a total of " << conflicts.size() << '\n');
87     return conflicts.contains(op);
88   }
89 
90   /// Return the use map.
91   /// The use map maps array access, amend, fetch and update operations back to
92   /// the array load that is the original source of the array value.
93   /// It maps an array_load to an array_merge_store, if and only if the loaded
94   /// array value has pending modifications to be merged.
95   const OperationUseMapT &getUseMap() const { return useMap; }
96 
97   /// Return the set of array_access ops directly associated with array_amend
98   /// ops.
99   bool inAmendAccessSet(mlir::Operation *op) const {
100     return amendAccesses.count(op);
101   }
102 
103   /// For ArrayLoad `load`, return the transitive set of all OpOperands.
104   UseSetT getLoadUseSet(mlir::Operation *load) const {
105     assert(loadMapSets.count(load) && "analysis missed an array load?");
106     return loadMapSets.lookup(load);
107   }
108 
109   void arrayMentions(llvm::SmallVectorImpl<mlir::Operation *> &mentions,
110                      ArrayLoadOp load);
111 
112 private:
113   void construct(mlir::Operation *topLevelOp);
114 
115   mlir::Operation *operation; // operation that analysis ran upon
116   ConflictSetT conflicts;     // set of conflicts (loads and merge stores)
117   OperationUseMapT useMap;
118   LoadMapSetsT loadMapSets;
119   // Set of array_access ops associated with array_amend ops.
120   AmendAccessSetT amendAccesses;
121   bool optimizeConflicts;
122 };
123 
124 // Optimized array copy analysis that takes into account Fortran
125 // variable attributes to prove that no conflict is possible
126 // and reduce the number of temporary arrays.
127 class ArrayCopyAnalysisOptimized : public ArrayCopyAnalysisBase {
128 public:
129   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ArrayCopyAnalysisOptimized)
130 
131   ArrayCopyAnalysisOptimized(mlir::Operation *op)
132       : ArrayCopyAnalysisBase(op, /*optimized=*/true) {}
133 };
134 
135 // Unoptimized array copy analysis used at O0.
136 class ArrayCopyAnalysis : public ArrayCopyAnalysisBase {
137 public:
138   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ArrayCopyAnalysis)
139 
140   ArrayCopyAnalysis(mlir::Operation *op)
141       : ArrayCopyAnalysisBase(op, /*optimized=*/false) {}
142 };
143 } // namespace
144 
145 namespace {
146 /// Helper class to collect all array operations that produced an array value.
147 class ReachCollector {
148 public:
149   ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach,
150                  mlir::Region *loopRegion)
151       : reach{reach}, loopRegion{loopRegion} {}
152 
153   void collectArrayMentionFrom(mlir::Operation *op, mlir::ValueRange range) {
154     if (range.empty()) {
155       collectArrayMentionFrom(op, mlir::Value{});
156       return;
157     }
158     for (mlir::Value v : range)
159       collectArrayMentionFrom(v);
160   }
161 
162   // Collect all the array_access ops in `block`. This recursively looks into
163   // blocks in ops with regions.
164   // FIXME: This is temporarily relying on the array_amend appearing in a
165   // do_loop Region.  This phase ordering assumption can be eliminated by using
166   // dominance information to find the array_access ops or by scanning the
167   // transitive closure of the amending array_access's users and the defs that
168   // reach them.
169   void collectAccesses(llvm::SmallVector<ArrayAccessOp> &result,
170                        mlir::Block *block) {
171     for (auto &op : *block) {
172       if (auto access = mlir::dyn_cast<ArrayAccessOp>(op)) {
173         LLVM_DEBUG(llvm::dbgs() << "adding access: " << access << '\n');
174         result.push_back(access);
175         continue;
176       }
177       for (auto &region : op.getRegions())
178         for (auto &bb : region.getBlocks())
179           collectAccesses(result, &bb);
180     }
181   }
182 
183   void collectArrayMentionFrom(mlir::Operation *op, mlir::Value val) {
184     // `val` is defined by an Op, process the defining Op.
185     // If `val` is defined by a region containing Op, we want to drill down
186     // and through that Op's region(s).
187     LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n');
188     auto popFn = [&](auto rop) {
189       assert(val && "op must have a result value");
190       auto resNum = mlir::cast<mlir::OpResult>(val).getResultNumber();
191       llvm::SmallVector<mlir::Value> results;
192       rop.resultToSourceOps(results, resNum);
193       for (auto u : results)
194         collectArrayMentionFrom(u);
195     };
196     if (auto rop = mlir::dyn_cast<DoLoopOp>(op)) {
197       popFn(rop);
198       return;
199     }
200     if (auto rop = mlir::dyn_cast<IterWhileOp>(op)) {
201       popFn(rop);
202       return;
203     }
204     if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) {
205       popFn(rop);
206       return;
207     }
208     if (auto box = mlir::dyn_cast<EmboxOp>(op)) {
209       for (auto *user : box.getMemref().getUsers())
210         if (user != op)
211           collectArrayMentionFrom(user, user->getResults());
212       return;
213     }
214     if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
215       if (opIsInsideLoops(mergeStore))
216         collectArrayMentionFrom(mergeStore.getSequence());
217       return;
218     }
219 
220     if (mlir::isa<AllocaOp, AllocMemOp>(op)) {
221       // Look for any stores inside the loops, and collect an array operation
222       // that produced the value being stored to it.
223       for (auto *user : op->getUsers())
224         if (auto store = mlir::dyn_cast<fir::StoreOp>(user))
225           if (opIsInsideLoops(store))
226             collectArrayMentionFrom(store.getValue());
227       return;
228     }
229 
230     // Scan the uses of amend's memref
231     if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op)) {
232       reach.push_back(op);
233       llvm::SmallVector<ArrayAccessOp> accesses;
234       collectAccesses(accesses, op->getBlock());
235       for (auto access : accesses)
236         collectArrayMentionFrom(access.getResult());
237     }
238 
239     // Otherwise, Op does not contain a region so just chase its operands.
240     if (mlir::isa<ArrayAccessOp, ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp,
241                   ArrayFetchOp>(op)) {
242       LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
243       reach.push_back(op);
244     }
245 
246     // Include all array_access ops using an array_load.
247     if (auto arrLd = mlir::dyn_cast<ArrayLoadOp>(op))
248       for (auto *user : arrLd.getResult().getUsers())
249         if (mlir::isa<ArrayAccessOp>(user)) {
250           LLVM_DEBUG(llvm::dbgs() << "add " << *user << " to reachable set\n");
251           reach.push_back(user);
252         }
253 
254     // Array modify assignment is performed on the result. So the analysis must
255     // look at what is done with the result.
256     if (mlir::isa<ArrayModifyOp>(op))
257       for (auto *user : op->getResult(0).getUsers())
258         followUsers(user);
259 
260     if (mlir::isa<fir::CallOp>(op)) {
261       LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
262       reach.push_back(op);
263     }
264 
265     for (auto u : op->getOperands())
266       collectArrayMentionFrom(u);
267   }
268 
269   void collectArrayMentionFrom(mlir::BlockArgument ba) {
270     auto *parent = ba.getOwner()->getParentOp();
271     // If inside an Op holding a region, the block argument corresponds to an
272     // argument passed to the containing Op.
273     auto popFn = [&](auto rop) {
274       collectArrayMentionFrom(rop.blockArgToSourceOp(ba.getArgNumber()));
275     };
276     if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) {
277       popFn(rop);
278       return;
279     }
280     if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) {
281       popFn(rop);
282       return;
283     }
284     // Otherwise, a block argument is provided via the pred blocks.
285     for (auto *pred : ba.getOwner()->getPredecessors()) {
286       auto u = pred->getTerminator()->getOperand(ba.getArgNumber());
287       collectArrayMentionFrom(u);
288     }
289   }
290 
291   // Recursively trace operands to find all array operations relating to the
292   // values merged.
293   void collectArrayMentionFrom(mlir::Value val) {
294     if (!val || visited.contains(val))
295       return;
296     visited.insert(val);
297 
298     // Process a block argument.
299     if (auto ba = mlir::dyn_cast<mlir::BlockArgument>(val)) {
300       collectArrayMentionFrom(ba);
301       return;
302     }
303 
304     // Process an Op.
305     if (auto *op = val.getDefiningOp()) {
306       collectArrayMentionFrom(op, val);
307       return;
308     }
309 
310     emitFatalError(val.getLoc(), "unhandled value");
311   }
312 
313   /// Return all ops that produce the array value that is stored into the
314   /// `array_merge_store`.
315   static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach,
316                              mlir::Value seq) {
317     reach.clear();
318     mlir::Region *loopRegion = nullptr;
319     if (auto doLoop = mlir::dyn_cast_or_null<DoLoopOp>(seq.getDefiningOp()))
320       loopRegion = &doLoop->getRegion(0);
321     ReachCollector collector(reach, loopRegion);
322     collector.collectArrayMentionFrom(seq);
323   }
324 
325 private:
326   /// Is \p op inside the loop nest region?
327   /// FIXME: replace this structural dependence with graph properties.
328   bool opIsInsideLoops(mlir::Operation *op) const {
329     auto *region = op->getParentRegion();
330     while (region) {
331       if (region == loopRegion)
332         return true;
333       region = region->getParentRegion();
334     }
335     return false;
336   }
337 
338   /// Recursively trace the uses of an operation's results, calling
339   /// collectArrayMentionFrom on the direct and indirect user operands.
340   void followUsers(mlir::Operation *op) {
341     for (auto userOperand : op->getOperands())
342       collectArrayMentionFrom(userOperand);
343     // Go through potential converts/coordinate_op.
344     for (auto indirectUser : op->getUsers())
345       followUsers(indirectUser);
346   }
347 
348   llvm::SmallVectorImpl<mlir::Operation *> &reach;
349   llvm::SmallPtrSet<mlir::Value, 16> visited;
350   /// Region of the loop nest that produced the array value.
351   mlir::Region *loopRegion;
352 };
353 } // namespace
354 
355 /// Find all the array operations that access the array value that is loaded by
356 /// the array load operation, `load`.
357 void ArrayCopyAnalysisBase::arrayMentions(
358     llvm::SmallVectorImpl<mlir::Operation *> &mentions, ArrayLoadOp load) {
359   mentions.clear();
360   auto lmIter = loadMapSets.find(load);
361   if (lmIter != loadMapSets.end()) {
362     for (auto *opnd : lmIter->second) {
363       auto *owner = opnd->getOwner();
364       if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
365                     ArrayModifyOp>(owner))
366         mentions.push_back(owner);
367     }
368     return;
369   }
370 
371   UseSetT visited;
372   llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig]
373 
374   auto appendToQueue = [&](mlir::Value val) {
375     for (auto &use : val.getUses())
376       if (!visited.count(&use)) {
377         visited.insert(&use);
378         queue.push_back(&use);
379       }
380   };
381 
382   // Build the set of uses of `original`.
383   // let USES = { uses of original fir.load }
384   appendToQueue(load);
385 
386   // Process the worklist until done.
387   while (!queue.empty()) {
388     mlir::OpOperand *operand = queue.pop_back_val();
389     mlir::Operation *owner = operand->getOwner();
390     if (!owner)
391       continue;
392     auto structuredLoop = [&](auto ro) {
393       if (auto blockArg = ro.iterArgToBlockArg(operand->get())) {
394         int64_t arg = blockArg.getArgNumber();
395         mlir::Value output = ro.getResult(ro.getFinalValue() ? arg : arg - 1);
396         appendToQueue(output);
397         appendToQueue(blockArg);
398       }
399     };
400     // TODO: this needs to be updated to use the control-flow interface.
401     auto branchOp = [&](mlir::Block *dest, OperandRange operands) {
402       if (operands.empty())
403         return;
404 
405       // Check if this operand is within the range.
406       unsigned operandIndex = operand->getOperandNumber();
407       unsigned operandsStart = operands.getBeginOperandIndex();
408       if (operandIndex < operandsStart ||
409           operandIndex >= (operandsStart + operands.size()))
410         return;
411 
412       // Index the successor.
413       unsigned argIndex = operandIndex - operandsStart;
414       appendToQueue(dest->getArgument(argIndex));
415     };
416     // Thread uses into structured loop bodies and return value uses.
417     if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) {
418       structuredLoop(ro);
419     } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) {
420       structuredLoop(ro);
421     } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) {
422       // Thread any uses of fir.if that return the marked array value.
423       mlir::Operation *parent = rs->getParentRegion()->getParentOp();
424       if (auto ifOp = mlir::dyn_cast<fir::IfOp>(parent))
425         appendToQueue(ifOp.getResult(operand->getOperandNumber()));
426     } else if (mlir::isa<ArrayFetchOp>(owner)) {
427       // Keep track of array value fetches.
428       LLVM_DEBUG(llvm::dbgs()
429                  << "add fetch {" << *owner << "} to array value set\n");
430       mentions.push_back(owner);
431     } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) {
432       // Keep track of array value updates and thread the return value uses.
433       LLVM_DEBUG(llvm::dbgs()
434                  << "add update {" << *owner << "} to array value set\n");
435       mentions.push_back(owner);
436       appendToQueue(update.getResult());
437     } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) {
438       // Keep track of array value modification and thread the return value
439       // uses.
440       LLVM_DEBUG(llvm::dbgs()
441                  << "add modify {" << *owner << "} to array value set\n");
442       mentions.push_back(owner);
443       appendToQueue(update.getResult(1));
444     } else if (auto mention = mlir::dyn_cast<ArrayAccessOp>(owner)) {
445       mentions.push_back(owner);
446     } else if (auto amend = mlir::dyn_cast<ArrayAmendOp>(owner)) {
447       mentions.push_back(owner);
448       appendToQueue(amend.getResult());
449     } else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) {
450       branchOp(br.getDest(), br.getDestOperands());
451     } else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) {
452       branchOp(br.getTrueDest(), br.getTrueOperands());
453       branchOp(br.getFalseDest(), br.getFalseOperands());
454     } else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
455       // do nothing
456     } else {
457       llvm::report_fatal_error("array value reached unexpected op");
458     }
459   }
460   loadMapSets.insert({load, visited});
461 }
462 
463 static bool hasPointerType(mlir::Type type) {
464   if (auto boxTy = mlir::dyn_cast<BoxType>(type))
465     type = boxTy.getEleTy();
466   return mlir::isa<fir::PointerType>(type);
467 }
468 
469 // This is an NF performance hack. It makes a simple test that the slices of the
470 // load, \p ld, and the merge store, \p st, are trivially mutually exclusive.
471 static bool mutuallyExclusiveSliceRange(ArrayLoadOp ld, ArrayMergeStoreOp st) {
472   // If the same array_load, then no further testing is warranted.
473   if (ld.getResult() == st.getOriginal())
474     return false;
475 
476   auto getSliceOp = [](mlir::Value val) -> SliceOp {
477     if (!val)
478       return {};
479     auto sliceOp = mlir::dyn_cast_or_null<SliceOp>(val.getDefiningOp());
480     if (!sliceOp)
481       return {};
482     return sliceOp;
483   };
484 
485   auto ldSlice = getSliceOp(ld.getSlice());
486   auto stSlice = getSliceOp(st.getSlice());
487   if (!ldSlice || !stSlice)
488     return false;
489 
490   // Give up on subobject slices.
491   if (!ldSlice.getFields().empty() || !stSlice.getFields().empty() ||
492       !ldSlice.getSubstr().empty() || !stSlice.getSubstr().empty())
493     return false;
494 
495   // Crudely test that the two slices do not overlap by looking for the
496   // following general condition. If the slices look like (i:j) and (j+1:k) then
497   // these ranges do not overlap. The addend must be a constant.
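  // For example, slices lowered from `a(1:n)` and `a(n+1:2*n)` have an upper
  // bound and a lower bound that differ by the constant 1, so they are
  // treated here as non-overlapping.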
498   auto ldTriples = ldSlice.getTriples();
499   auto stTriples = stSlice.getTriples();
500   const auto size = ldTriples.size();
501   if (size != stTriples.size())
502     return false;
503 
504   auto displacedByConstant = [](mlir::Value v1, mlir::Value v2) {
505     auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
506       auto *op = v.getDefiningOp();
507       while (auto conv = mlir::dyn_cast_or_null<ConvertOp>(op))
508         op = conv.getValue().getDefiningOp();
509       return op;
510     };
511 
512     auto isPositiveConstant = [](mlir::Value v) -> bool {
513       if (auto conOp =
514               mlir::dyn_cast<mlir::arith::ConstantOp>(v.getDefiningOp()))
515         if (auto iattr = mlir::dyn_cast<mlir::IntegerAttr>(conOp.getValue()))
516           return iattr.getInt() > 0;
517       return false;
518     };
519 
520     auto *op1 = removeConvert(v1);
521     auto *op2 = removeConvert(v2);
522     if (!op1 || !op2)
523       return false;
524     if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
525       if ((addi.getLhs().getDefiningOp() == op1 &&
526            isPositiveConstant(addi.getRhs())) ||
527           (addi.getRhs().getDefiningOp() == op1 &&
528            isPositiveConstant(addi.getLhs())))
529         return true;
530     if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
531       if (subi.getLhs().getDefiningOp() == op2 &&
532           isPositiveConstant(subi.getRhs()))
533         return true;
534     return false;
535   };
536 
537   for (std::remove_const_t<decltype(size)> i = 0; i < size; i += 3) {
538     // If both are loop invariant, skip to the next triple.
539     if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i + 1].getDefiningOp()) &&
540         mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i + 1].getDefiningOp())) {
541       // Unless either is a vector index, then be conservative.
542       if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i].getDefiningOp()) ||
543           mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i].getDefiningOp()))
544         return false;
545       continue;
546     }
547     // If identical, skip to the next triple.
548     if (ldTriples[i] == stTriples[i] && ldTriples[i + 1] == stTriples[i + 1] &&
549         ldTriples[i + 2] == stTriples[i + 2])
550       continue;
551     // If the ubound of one triple and the lbound of the other differ by a
552     // positive constant, skip to the next triple.
553     if (displacedByConstant(ldTriples[i + 1], stTriples[i]) ||
554         displacedByConstant(stTriples[i + 1], ldTriples[i]))
555       continue;
556     return false;
557   }
558   LLVM_DEBUG(llvm::dbgs() << "detected non-overlapping slice ranges on " << ld
559                           << " and " << st << ", which is not a conflict\n");
560   return true;
561 }
562 
563 /// Is there a conflict between the updated array value that is to be stored
564 /// to `st` and the set of arrays loaded (`reach`) and used to compute the
565 /// updated value?
566 /// If `optimize` is true, use the variable attributes to prove that
567 /// there is no conflict.
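/// For example, if the merge store target is a POINTER while a loaded array
/// is neither a POINTER nor a TARGET, Fortran aliasing rules guarantee the
/// two cannot overlap, so no temporary is needed in that case.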
568 static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach,
569                            ArrayMergeStoreOp st, bool optimize) {
570   mlir::Value load;
571   mlir::Value addr = st.getMemref();
572   const bool storeHasPointerType = hasPointerType(addr.getType());
573   for (auto *op : reach)
574     if (auto ld = mlir::dyn_cast<ArrayLoadOp>(op)) {
575       mlir::Type ldTy = ld.getMemref().getType();
576       auto globalOpName = mlir::OperationName(fir::GlobalOp::getOperationName(),
577                                               ld.getContext());
578       if (ld.getMemref() == addr) {
579         if (mutuallyExclusiveSliceRange(ld, st))
580           continue;
581         if (ld.getResult() != st.getOriginal())
582           return true;
583         if (load) {
584           // TODO: extend this to allow checking if the first `load` and this
585           // `ld` are mutually exclusive accesses but not identical.
586           return true;
587         }
588         load = ld;
589       } else if (storeHasPointerType) {
590         if (optimize && !hasPointerType(ldTy) &&
591             !valueMayHaveFirAttributes(
592                 ld.getMemref(),
593                 {getTargetAttrName(),
594                  fir::GlobalOp::getTargetAttrName(globalOpName).strref()}))
595           continue;
596 
597         return true;
598       } else if (hasPointerType(ldTy)) {
599         if (optimize && !storeHasPointerType &&
600             !valueMayHaveFirAttributes(
601                 addr,
602                 {getTargetAttrName(),
603                  fir::GlobalOp::getTargetAttrName(globalOpName).strref()}))
604           continue;
605 
606         return true;
607       }
608       // TODO: Check if types can also allow ruling out some cases. For now,
609       // the fact that equivalences use the pointer attribute to enforce
610       // aliasing prevents any attempt to do so. In general, it may also be
611       // wrong to rely on the types if any of them is a complex or a derived
612       // type for which it is possible to create a pointer to a part with a
613       // different type than the whole, although this deserves more
614       // investigation because existing compiler behaviors seem to diverge
615       // here.
616     }
617   return false;
618 }
619 
620 /// Is there an access vector conflict on the array being merged into? If the
621 /// access vectors diverge, then assume that there are potentially overlapping
622 /// loop-carried references.
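/// For example, a FORALL assignment like `forall (i=2:n) a(i) = a(i-1)` leads
/// to an array_fetch and an array_update of the same loaded array with index
/// vectors that differ, which is reported as a potential loop-carried
/// dependence.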
623 static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> mentions) {
624   if (mentions.size() < 2)
625     return false;
626   llvm::SmallVector<mlir::Value> indices;
627   LLVM_DEBUG(llvm::dbgs() << "check merge conflict with " << mentions.size()
628                           << " mentions on the list\n");
629   bool valSeen = false;
630   bool refSeen = false;
631   for (auto *op : mentions) {
632     llvm::SmallVector<mlir::Value> compareVector;
633     if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) {
634       valSeen = true;
635       if (indices.empty()) {
636         indices = u.getIndices();
637         continue;
638       }
639       compareVector = u.getIndices();
640     } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) {
641       valSeen = true;
642       if (indices.empty()) {
643         indices = f.getIndices();
644         continue;
645       }
646       compareVector = f.getIndices();
647     } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) {
648       valSeen = true;
649       if (indices.empty()) {
650         indices = f.getIndices();
651         continue;
652       }
653       compareVector = f.getIndices();
654     } else if (auto f = mlir::dyn_cast<ArrayAccessOp>(op)) {
655       refSeen = true;
656       if (indices.empty()) {
657         indices = f.getIndices();
658         continue;
659       }
660       compareVector = f.getIndices();
661     } else if (mlir::isa<ArrayAmendOp>(op)) {
662       refSeen = true;
663       continue;
664     } else {
665       mlir::emitError(op->getLoc(), "unexpected operation in analysis");
666     }
667     if (compareVector.size() != indices.size() ||
668         llvm::any_of(llvm::zip(compareVector, indices), [&](auto pair) {
669           return std::get<0>(pair) != std::get<1>(pair);
670         }))
671       return true;
672     LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n");
673   }
674   return valSeen && refSeen;
675 }
676 
677 /// With element-by-reference semantics, an amended array with more than one
678 /// access to the same loaded array is conservatively considered a conflict.
679 /// Note: the array copy can still be eliminated in subsequent optimizations.
680 static bool conflictOnReference(llvm::ArrayRef<mlir::Operation *> mentions) {
681   LLVM_DEBUG(llvm::dbgs() << "checking reference semantics " << mentions.size()
682                           << '\n');
683   if (mentions.size() < 3)
684     return false;
685   unsigned amendCount = 0;
686   unsigned accessCount = 0;
687   for (auto *op : mentions) {
688     if (mlir::isa<ArrayAmendOp>(op) && ++amendCount > 1) {
689       LLVM_DEBUG(llvm::dbgs() << "conflict: multiple amends of array value\n");
690       return true;
691     }
692     if (mlir::isa<ArrayAccessOp>(op) && ++accessCount > 1) {
693       LLVM_DEBUG(llvm::dbgs()
694                  << "conflict: multiple accesses of array value\n");
695       return true;
696     }
697     if (mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) {
698       LLVM_DEBUG(llvm::dbgs()
699                  << "conflict: array value has both uses by-value and uses "
700                     "by-reference. conservative assumption.\n");
701       return true;
702     }
703   }
704   return false;
705 }
706 
707 static mlir::Operation *
708 amendingAccess(llvm::ArrayRef<mlir::Operation *> mentions) {
709   for (auto *op : mentions)
710     if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op))
711       return amend.getMemref().getDefiningOp();
712   return {};
713 }
714 
715 // Are any conflicts present? The conflicts detected here are described above.
716 static bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
717                              llvm::ArrayRef<mlir::Operation *> mentions,
718                              ArrayMergeStoreOp st, bool optimize) {
719   return conflictOnLoad(reach, st, optimize) || conflictOnMerge(mentions);
720 }
721 
722 // Assume that any call to a function that uses host-associations will be
723 // modifying the output array.
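// For example, in `a = f(b)` where `f` is an internal procedure of the
// current scoping unit, `f` may modify `a` through host association, so a
// temporary copy is conservatively assumed to be required.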
724 static bool
725 conservativeCallConflict(llvm::ArrayRef<mlir::Operation *> reaches) {
726   return llvm::any_of(reaches, [](mlir::Operation *op) {
727     if (auto call = mlir::dyn_cast<fir::CallOp>(op))
728       if (auto callee = mlir::dyn_cast<mlir::SymbolRefAttr>(
729               call.getCallableForCallee())) {
730         auto module = op->getParentOfType<mlir::ModuleOp>();
731         return isInternalProcedure(
732             module.lookupSymbol<mlir::func::FuncOp>(callee));
733       }
734     return false;
735   });
736 }
737 
738 /// Constructor of the array copy analysis.
739 /// This performs the analysis and saves the intermediate results.
740 void ArrayCopyAnalysisBase::construct(mlir::Operation *topLevelOp) {
741   topLevelOp->walk([&](Operation *op) {
742     if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) {
743       llvm::SmallVector<mlir::Operation *> values;
744       ReachCollector::reachingValues(values, st.getSequence());
745       bool callConflict = conservativeCallConflict(values);
746       llvm::SmallVector<mlir::Operation *> mentions;
747       arrayMentions(mentions,
748                     mlir::cast<ArrayLoadOp>(st.getOriginal().getDefiningOp()));
749       bool conflict = conflictDetected(values, mentions, st, optimizeConflicts);
750       bool refConflict = conflictOnReference(mentions);
751       if (callConflict || conflict || refConflict) {
752         LLVM_DEBUG(llvm::dbgs()
753                    << "CONFLICT: copies required for " << st << '\n'
754                    << "   adding conflicts on: " << *op << " and "
755                    << st.getOriginal() << '\n');
756         conflicts.insert(op);
757         conflicts.insert(st.getOriginal().getDefiningOp());
758         if (auto *access = amendingAccess(mentions))
759           amendAccesses.insert(access);
760       }
761       auto *ld = st.getOriginal().getDefiningOp();
762       LLVM_DEBUG(llvm::dbgs()
763                  << "map: adding {" << *ld << " -> " << st << "}\n");
764       useMap.insert({ld, op});
765     } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) {
766       llvm::SmallVector<mlir::Operation *> mentions;
767       arrayMentions(mentions, load);
768       LLVM_DEBUG(llvm::dbgs() << "process load: " << load
769                               << ", mentions: " << mentions.size() << '\n');
770       for (auto *acc : mentions) {
771         LLVM_DEBUG(llvm::dbgs() << " mention: " << *acc << '\n');
772         if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
773                       ArrayModifyOp>(acc)) {
774           if (useMap.count(acc)) {
775             mlir::emitError(
776                 load.getLoc(),
777                 "The parallel semantics of multiple array_merge_stores per "
778                 "array_load are not supported.");
779             continue;
780           }
781           LLVM_DEBUG(llvm::dbgs()
782                      << "map: adding {" << *acc << "} -> {" << load << "}\n");
783           useMap.insert({acc, op});
784         }
785       }
786     }
787   });
788 }
789 
790 //===----------------------------------------------------------------------===//
791 // Conversions for converting out of array value form.
792 //===----------------------------------------------------------------------===//
793 
794 namespace {
795 class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> {
796 public:
797   using OpRewritePattern::OpRewritePattern;
798 
799   llvm::LogicalResult
800   matchAndRewrite(ArrayLoadOp load,
801                   mlir::PatternRewriter &rewriter) const override {
802     LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n");
803     rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType());
804     return mlir::success();
805   }
806 };
807 
808 class ArrayMergeStoreConversion
809     : public mlir::OpRewritePattern<ArrayMergeStoreOp> {
810 public:
811   using OpRewritePattern::OpRewritePattern;
812 
813   llvm::LogicalResult
814   matchAndRewrite(ArrayMergeStoreOp store,
815                   mlir::PatternRewriter &rewriter) const override {
816     LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n");
817     rewriter.eraseOp(store);
818     return mlir::success();
819   }
820 };
821 } // namespace
822 
823 static mlir::Type getEleTy(mlir::Type ty) {
824   auto eleTy = unwrapSequenceType(unwrapPassByRefType(ty));
825   // FIXME: keep ptr/heap/ref information.
826   return ReferenceType::get(eleTy);
827 }
828 
829 // This is an unsafe way to deduce this (won't be true in internal
830 // procedure or inside select-rank for assumed-size). Only here to satisfy
831 // legacy code until removed.
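// For example, an assumed-size dummy declared as `a(2,*)` is lowered with a
// trailing extent constant of -1, which is the marker this test looks for.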
832 static bool isAssumedSize(llvm::SmallVectorImpl<mlir::Value> &extents) {
833   if (extents.empty())
834     return false;
835   auto cstLen = fir::getIntIfConstant(extents.back());
836   return cstLen.has_value() && *cstLen == -1;
837 }
838 
839 // Extract extents from the ShapeOp/ShapeShiftOp into the result vector.
840 static bool getAdjustedExtents(mlir::Location loc,
841                                mlir::PatternRewriter &rewriter,
842                                ArrayLoadOp arrLoad,
843                                llvm::SmallVectorImpl<mlir::Value> &result,
844                                mlir::Value shape) {
845   bool copyUsingSlice = false;
846   auto *shapeOp = shape.getDefiningOp();
847   if (auto s = mlir::dyn_cast_or_null<ShapeOp>(shapeOp)) {
848     auto e = s.getExtents();
849     result.insert(result.end(), e.begin(), e.end());
850   } else if (auto s = mlir::dyn_cast_or_null<ShapeShiftOp>(shapeOp)) {
851     auto e = s.getExtents();
852     result.insert(result.end(), e.begin(), e.end());
853   } else {
854     emitFatalError(loc, "not a fir.shape/fir.shape_shift op");
855   }
856   auto idxTy = rewriter.getIndexType();
857   if (isAssumedSize(result)) {
858     // Use slice information to compute the extent of the column.
859     auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
860     mlir::Value size = one;
861     if (mlir::Value sliceArg = arrLoad.getSlice()) {
862       if (auto sliceOp =
863               mlir::dyn_cast_or_null<SliceOp>(sliceArg.getDefiningOp())) {
864         auto triples = sliceOp.getTriples();
865         const std::size_t tripleSize = triples.size();
866         auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
867         FirOpBuilder builder(rewriter, module);
868         size = builder.genExtentFromTriplet(loc, triples[tripleSize - 3],
869                                             triples[tripleSize - 2],
870                                             triples[tripleSize - 1], idxTy);
871         copyUsingSlice = true;
872       }
873     }
874     result[result.size() - 1] = size;
875   }
876   return copyUsingSlice;
877 }
878 
879 /// Place the extents of the array load, \p arrLoad, into \p result and
880 /// return a ShapeOp or ShapeShiftOp with the same extents. If \p arrLoad is
881 /// loading a `!fir.box`, code will be generated to read the extents from the
882 /// boxed value, and the returned shape Op will be built with the extents read
883 /// from the box. Otherwise, the extents will be extracted from the ShapeOp (or
884 /// ShapeShiftOp) argument of \p arrLoad. \p copyUsingSlice will be set to true
885 /// if slicing of the output array is to be done in the copy-in/copy-out rather
886 /// than in the elemental computation step.
887 static mlir::Value getOrReadExtentsAndShapeOp(
888     mlir::Location loc, mlir::PatternRewriter &rewriter, ArrayLoadOp arrLoad,
889     llvm::SmallVectorImpl<mlir::Value> &result, bool &copyUsingSlice) {
890   assert(result.empty());
891   if (arrLoad->hasAttr(fir::getOptionalAttrName()))
892     fir::emitFatalError(
893         loc, "shapes from array load of OPTIONAL arrays must not be used");
894   if (auto boxTy = mlir::dyn_cast<BoxType>(arrLoad.getMemref().getType())) {
895     auto rank =
896         mlir::cast<SequenceType>(dyn_cast_ptrOrBoxEleTy(boxTy)).getDimension();
897     auto idxTy = rewriter.getIndexType();
898     for (decltype(rank) dim = 0; dim < rank; ++dim) {
899       auto dimVal = rewriter.create<mlir::arith::ConstantIndexOp>(loc, dim);
900       auto dimInfo = rewriter.create<BoxDimsOp>(loc, idxTy, idxTy, idxTy,
901                                                 arrLoad.getMemref(), dimVal);
902       result.emplace_back(dimInfo.getResult(1));
903     }
904     if (!arrLoad.getShape()) {
905       auto shapeType = ShapeType::get(rewriter.getContext(), rank);
906       return rewriter.create<ShapeOp>(loc, shapeType, result);
907     }
908     auto shiftOp = arrLoad.getShape().getDefiningOp<ShiftOp>();
909     auto shapeShiftType = ShapeShiftType::get(rewriter.getContext(), rank);
910     llvm::SmallVector<mlir::Value> shapeShiftOperands;
911     for (auto [lb, extent] : llvm::zip(shiftOp.getOrigins(), result)) {
912       shapeShiftOperands.push_back(lb);
913       shapeShiftOperands.push_back(extent);
914     }
915     return rewriter.create<ShapeShiftOp>(loc, shapeShiftType,
916                                          shapeShiftOperands);
917   }
918   copyUsingSlice =
919       getAdjustedExtents(loc, rewriter, arrLoad, result, arrLoad.getShape());
920   return arrLoad.getShape();
921 }
922 
923 static mlir::Type toRefType(mlir::Type ty) {
924   if (fir::isa_ref_type(ty))
925     return ty;
926   return fir::ReferenceType::get(ty);
927 }
928 
929 static llvm::SmallVector<mlir::Value>
930 getTypeParamsIfRawData(mlir::Location loc, FirOpBuilder &builder,
931                        ArrayLoadOp arrLoad, mlir::Type ty) {
932   if (mlir::isa<BoxType>(ty))
933     return {};
934   return fir::factory::getTypeParams(loc, builder, arrLoad);
935 }
936 
937 static mlir::Value genCoorOp(mlir::PatternRewriter &rewriter,
938                              mlir::Location loc, mlir::Type eleTy,
939                              mlir::Type resTy, mlir::Value alloc,
940                              mlir::Value shape, mlir::Value slice,
941                              mlir::ValueRange indices, ArrayLoadOp load,
942                              bool skipOrig = false) {
943   llvm::SmallVector<mlir::Value> originated;
944   if (skipOrig)
945     originated.assign(indices.begin(), indices.end());
946   else
947     originated = factory::originateIndices(loc, rewriter, alloc.getType(),
948                                            shape, indices);
949   auto seqTy = dyn_cast_ptrOrBoxEleTy(alloc.getType());
950   assert(seqTy && mlir::isa<SequenceType>(seqTy));
951   const auto dimension = mlir::cast<SequenceType>(seqTy).getDimension();
952   auto module = load->getParentOfType<mlir::ModuleOp>();
953   FirOpBuilder builder(rewriter, module);
954   auto typeparams = getTypeParamsIfRawData(loc, builder, load, alloc.getType());
955   mlir::Value result = rewriter.create<ArrayCoorOp>(
956       loc, eleTy, alloc, shape, slice,
957       llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension),
958       typeparams);
959   if (dimension < originated.size())
960     result = rewriter.create<fir::CoordinateOp>(
961         loc, resTy, result,
962         llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension));
963   return result;
964 }
965 
966 static mlir::Value getCharacterLen(mlir::Location loc, FirOpBuilder &builder,
967                                    ArrayLoadOp load, CharacterType charTy) {
968   auto charLenTy = builder.getCharacterLengthType();
969   if (charTy.hasDynamicLen()) {
970     if (mlir::isa<BoxType>(load.getMemref().getType())) {
971       // The loaded array is an emboxed value. Get the CHARACTER length from
972       // the box value.
973       auto eleSzInBytes =
974           builder.create<BoxEleSizeOp>(loc, charLenTy, load.getMemref());
975       auto kindSize =
976           builder.getKindMap().getCharacterBitsize(charTy.getFKind());
977       auto kindByteSize =
978           builder.createIntegerConstant(loc, charLenTy, kindSize / 8);
979       return builder.create<mlir::arith::DivSIOp>(loc, eleSzInBytes,
980                                                   kindByteSize);
981     }
982     // The loaded array is a (set of) unboxed values. If the CHARACTER's
983     // length is not a constant, it must be provided as a type parameter to
984     // the array_load.
985     auto typeparams = load.getTypeparams();
986     assert(typeparams.size() > 0 && "expected type parameters on array_load");
987     return typeparams.back();
988   }
989   // The typical case: the length of the CHARACTER is a compile-time
990   // constant that is encoded in the type information.
991   return builder.createIntegerConstant(loc, charLenTy, charTy.getLen());
992 }
993 /// Generate a shallow array copy. This is used for both copy-in and copy-out.
994 template <bool CopyIn>
995 void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter,
996                   mlir::Value dst, mlir::Value src, mlir::Value shapeOp,
997                   mlir::Value sliceOp, ArrayLoadOp arrLoad) {
998   auto insPt = rewriter.saveInsertionPoint();
999   llvm::SmallVector<mlir::Value> indices;
1000   llvm::SmallVector<mlir::Value> extents;
1001   bool copyUsingSlice =
1002       getAdjustedExtents(loc, rewriter, arrLoad, extents, shapeOp);
1003   auto idxTy = rewriter.getIndexType();
1004   // Build loop nest from column to row.
1005   for (auto sh : llvm::reverse(extents)) {
1006     auto ubi = rewriter.create<ConvertOp>(loc, idxTy, sh);
1007     auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
1008     auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
1009     auto ub = rewriter.create<mlir::arith::SubIOp>(loc, idxTy, ubi, one);
1010     auto loop = rewriter.create<DoLoopOp>(loc, zero, ub, one);
1011     rewriter.setInsertionPointToStart(loop.getBody());
1012     indices.push_back(loop.getInductionVar());
1013   }
1014   // Reverse the indices so they are in column-major order.
1015   std::reverse(indices.begin(), indices.end());
1016   auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
1017   FirOpBuilder builder(rewriter, module);
1018   auto fromAddr = rewriter.create<ArrayCoorOp>(
1019       loc, getEleTy(src.getType()), src, shapeOp,
1020       CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
1021       factory::originateIndices(loc, rewriter, src.getType(), shapeOp, indices),
1022       getTypeParamsIfRawData(loc, builder, arrLoad, src.getType()));
1023   auto toAddr = rewriter.create<ArrayCoorOp>(
1024       loc, getEleTy(dst.getType()), dst, shapeOp,
1025       !CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
1026       factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, indices),
1027       getTypeParamsIfRawData(loc, builder, arrLoad, dst.getType()));
1028   auto eleTy = unwrapSequenceType(unwrapPassByRefType(dst.getType()));
1029   // Copy from (to) object to (from) temp copy of same object.
1030   if (auto charTy = mlir::dyn_cast<CharacterType>(eleTy)) {
1031     auto len = getCharacterLen(loc, builder, arrLoad, charTy);
1032     CharBoxValue toChar(toAddr, len);
1033     CharBoxValue fromChar(fromAddr, len);
1034     factory::genScalarAssignment(builder, loc, toChar, fromChar);
1035   } else {
1036     if (hasDynamicSize(eleTy))
1037       TODO(loc, "copy element of dynamic size");
1038     factory::genScalarAssignment(builder, loc, toAddr, fromAddr);
1039   }
1040   rewriter.restoreInsertionPoint(insPt);
1041 }
1042 
1043 /// The array load may be either a boxed or unboxed value. If the value is
1044 /// boxed, we read the type parameters from the boxed value.
1045 static llvm::SmallVector<mlir::Value>
1046 genArrayLoadTypeParameters(mlir::Location loc, mlir::PatternRewriter &rewriter,
1047                            ArrayLoadOp load) {
1048   if (load.getTypeparams().empty()) {
1049     auto eleTy =
1050         unwrapSequenceType(unwrapPassByRefType(load.getMemref().getType()));
1051     if (hasDynamicSize(eleTy)) {
1052       if (auto charTy = mlir::dyn_cast<CharacterType>(eleTy)) {
1053         assert(mlir::isa<BoxType>(load.getMemref().getType()));
1054         auto module = load->getParentOfType<mlir::ModuleOp>();
1055         FirOpBuilder builder(rewriter, module);
1056         return {getCharacterLen(loc, builder, load, charTy)};
1057       }
1058       TODO(loc, "unhandled dynamic type parameters");
1059     }
1060     return {};
1061   }
1062   return load.getTypeparams();
1063 }
1064 
1065 static llvm::SmallVector<mlir::Value>
1066 findNonconstantExtents(mlir::Type memrefTy,
1067                        llvm::ArrayRef<mlir::Value> extents) {
1068   llvm::SmallVector<mlir::Value> nce;
1069   auto arrTy = unwrapPassByRefType(memrefTy);
1070   auto seqTy = mlir::cast<SequenceType>(arrTy);
1071   for (auto [s, x] : llvm::zip(seqTy.getShape(), extents))
1072     if (s == SequenceType::getUnknownExtent())
1073       nce.emplace_back(x);
1074   if (extents.size() > seqTy.getShape().size())
1075     for (auto x : extents.drop_front(seqTy.getShape().size()))
1076       nce.emplace_back(x);
1077   return nce;
1078 }
1079 
1080 /// Allocate temporary storage for an ArrayLoadOp \p load and initialize any
1081 /// allocatable direct components of the array elements with an unallocated
1082 /// status. Returns the temporary address as well as a callback to generate the
1083 /// temporary clean-up once it has been used. The clean-up will take care of
1084 /// deallocating all the element allocatable components that may have been
1085 /// allocated while using the temporary.
1086 static std::pair<mlir::Value,
1087                  std::function<void(mlir::PatternRewriter &rewriter)>>
1088 allocateArrayTemp(mlir::Location loc, mlir::PatternRewriter &rewriter,
1089                   ArrayLoadOp load, llvm::ArrayRef<mlir::Value> extents,
1090                   mlir::Value shape) {
1091   mlir::Type baseType = load.getMemref().getType();
1092   llvm::SmallVector<mlir::Value> nonconstantExtents =
1093       findNonconstantExtents(baseType, extents);
1094   llvm::SmallVector<mlir::Value> typeParams =
1095       genArrayLoadTypeParameters(loc, rewriter, load);
1096   mlir::Value allocmem = rewriter.create<AllocMemOp>(
1097       loc, dyn_cast_ptrOrBoxEleTy(baseType), typeParams, nonconstantExtents);
1098   mlir::Type eleType =
1099       fir::unwrapSequenceType(fir::unwrapPassByRefType(baseType));
1100   if (fir::isRecordWithAllocatableMember(eleType)) {
1101     // The allocatable component descriptors need to be set to a clean
1102     // deallocated status before anything is done with them.
1103     mlir::Value box = rewriter.create<fir::EmboxOp>(
1104         loc, fir::BoxType::get(allocmem.getType()), allocmem, shape,
1105         /*slice=*/mlir::Value{}, typeParams);
1106     auto module = load->getParentOfType<mlir::ModuleOp>();
1107     FirOpBuilder builder(rewriter, module);
1108     runtime::genDerivedTypeInitialize(builder, loc, box);
1109     // Any allocatable component that may have been allocated must be
1110     // deallocated during the clean-up.
1111     auto cleanup = [=](mlir::PatternRewriter &r) {
1112       FirOpBuilder builder(r, module);
1113       runtime::genDerivedTypeDestroy(builder, loc, box);
1114       r.create<FreeMemOp>(loc, allocmem);
1115     };
1116     return {allocmem, cleanup};
1117   }
1118   auto cleanup = [=](mlir::PatternRewriter &r) {
1119     r.create<FreeMemOp>(loc, allocmem);
1120   };
1121   return {allocmem, cleanup};
1122 }
1123 
1124 namespace {
1125 /// Conversion of fir.array_update and fir.array_modify Ops.
1126 /// If there is a conflict for the update, then we need to perform a
1127 /// copy-in/copy-out to preserve the original values of the array. If there is
1128 /// no conflict, then it is safe to eschew making any copies.
1129 template <typename ArrayOp>
1130 class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> {
1131 public:
1132   // TODO: Implement copy/swap semantics?
1133   explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx,
1134                                      const ArrayCopyAnalysisBase &a,
1135                                      const OperationUseMapT &m)
1136       : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {}
1137 
1138   /// The array_access, \p access, is to be made to a cloned copy due to a
1139   /// potential conflict. Uses copy-in/copy-out semantics and not copy/swap.
1140   mlir::Value referenceToClone(mlir::Location loc,
1141                                mlir::PatternRewriter &rewriter,
1142                                ArrayOp access) const {
1143     LLVM_DEBUG(llvm::dbgs()
1144                << "generating copy-in/copy-out loops for " << access << '\n');
1145     auto *op = access.getOperation();
1146     auto *loadOp = useMap.lookup(op);
1147     auto load = mlir::cast<ArrayLoadOp>(loadOp);
1148     auto eleTy = access.getType();
1149     rewriter.setInsertionPoint(loadOp);
1150     // Copy in.
1151     llvm::SmallVector<mlir::Value> extents;
1152     bool copyUsingSlice = false;
1153     auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
1154                                               copyUsingSlice);
1155     auto [allocmem, genTempCleanUp] =
1156         allocateArrayTemp(loc, rewriter, load, extents, shapeOp);
1157     genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
1158                                   load.getMemref(), shapeOp, load.getSlice(),
1159                                   load);
1160     // Generate the reference for the access.
1161     rewriter.setInsertionPoint(op);
1162     auto coor = genCoorOp(
1163         rewriter, loc, getEleTy(load.getType()), eleTy, allocmem, shapeOp,
1164         copyUsingSlice ? mlir::Value{} : load.getSlice(), access.getIndices(),
1165         load, access->hasAttr(factory::attrFortranArrayOffsets()));
1166     // Find the array_merge_store that consumes this load's updated value.
1167     auto *storeOp = useMap.lookup(loadOp);
1168     auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
1169     rewriter.setInsertionPoint(storeOp);
1170     // Copy out.
1171     genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter, store.getMemref(),
1172                                    allocmem, shapeOp, store.getSlice(), load);
1173     genTempCleanUp(rewriter);
1174     return coor;
1175   }
1176 
1177   /// Copy the RHS element into the LHS and insert copy-in/copy-out between a
1178   /// temp and the LHS if the analysis found potential overlaps between the RHS
1179   /// and LHS arrays. The element copy generator must be provided in \p
1180   /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp.
1181   /// Returns the address of the LHS element inside the loop and the LHS
1182   /// ArrayLoad result.
1183   std::pair<mlir::Value, mlir::Value>
1184   materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter,
1185                         ArrayOp update,
1186                         const std::function<void(mlir::Value)> &assignElement,
1187                         mlir::Type lhsEltRefType) const {
1188     auto *op = update.getOperation();
1189     auto *loadOp = useMap.lookup(op);
1190     auto load = mlir::cast<ArrayLoadOp>(loadOp);
1191     LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n");
1192     if (analysis.hasPotentialConflict(loadOp)) {
1193       // If there is a conflict between the arrays, then we copy the lhs array
1194       // to a temporary, update the temporary, and copy the temporary back to
1195       // the lhs array. This yields Fortran's copy-in copy-out array semantics.
1196       LLVM_DEBUG(llvm::dbgs() << "Yes, a conflict was found\n");
1197       rewriter.setInsertionPoint(loadOp);
1198       // Copy in.
1199       llvm::SmallVector<mlir::Value> extents;
1200       bool copyUsingSlice = false;
1201       auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
1202                                                 copyUsingSlice);
1203       auto [allocmem, genTempCleanUp] =
1204           allocateArrayTemp(loc, rewriter, load, extents, shapeOp);
1205 
1206       genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
1207                                     load.getMemref(), shapeOp, load.getSlice(),
1208                                     load);
1209       rewriter.setInsertionPoint(op);
1210       auto coor = genCoorOp(
1211           rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem,
1212           shapeOp, copyUsingSlice ? mlir::Value{} : load.getSlice(),
1213           update.getIndices(), load,
1214           update->hasAttr(factory::attrFortranArrayOffsets()));
1215       assignElement(coor);
1216       auto *storeOp = useMap.lookup(loadOp);
1217       auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
1218       rewriter.setInsertionPoint(storeOp);
1219       // Copy out.
1220       genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter,
1221                                      store.getMemref(), allocmem, shapeOp,
1222                                      store.getSlice(), load);
1223       genTempCleanUp(rewriter);
1224       return {coor, load.getResult()};
1225     }
1226     // Otherwise, when there is no conflict (no overlapping storage and no
1227     // loop-carried dependence), the lhs array can be updated in place.
1228     LLVM_DEBUG(llvm::dbgs() << "No conflict was found\n");
1229     rewriter.setInsertionPoint(op);
1230     auto coorTy = getEleTy(load.getType());
1231     auto coor =
1232         genCoorOp(rewriter, loc, coorTy, lhsEltRefType, load.getMemref(),
1233                   load.getShape(), load.getSlice(), update.getIndices(), load,
1234                   update->hasAttr(factory::attrFortranArrayOffsets()));
1235     assignElement(coor);
1236     return {coor, load.getResult()};
1237   }
1238 
1239 protected:
1240   const ArrayCopyAnalysisBase &analysis;
1241   const OperationUseMapT &useMap;
1242 };
1243 
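/// Convert an array_update op into a store through the lhs element address
/// produced by materializeAssignment(). Roughly (operands and types elided
/// with `...`; names are illustrative only):
/// ```mlir
///   %a1 = fir.array_update %a, %v, %i, %j : ...
///   // becomes
///   %addr = fir.array_coor ...
///   fir.store %v to %addr : ...
/// ```
/// plus copy-in/copy-out loops around the store when the analysis reports a
/// potential conflict.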
1244 class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> {
1245 public:
1246   explicit ArrayUpdateConversion(mlir::MLIRContext *ctx,
1247                                  const ArrayCopyAnalysisBase &a,
1248                                  const OperationUseMapT &m)
1249       : ArrayUpdateConversionBase{ctx, a, m} {}
1250 
1251   llvm::LogicalResult
1252   matchAndRewrite(ArrayUpdateOp update,
1253                   mlir::PatternRewriter &rewriter) const override {
1254     auto loc = update.getLoc();
1255     auto assignElement = [&](mlir::Value coor) {
1256       auto input = update.getMerge();
1257       if (auto inEleTy = dyn_cast_ptrEleTy(input.getType())) {
1258         emitFatalError(loc, "array_update on references not supported");
1259       } else {
1260         rewriter.create<fir::StoreOp>(loc, input, coor);
1261       }
1262     };
1263     auto lhsEltRefType = toRefType(update.getMerge().getType());
1264     auto [_, lhsLoadResult] = materializeAssignment(
1265         loc, rewriter, update, assignElement, lhsEltRefType);
1266     update.replaceAllUsesWith(lhsLoadResult);
1267     rewriter.replaceOp(update, lhsLoadResult);
1268     return mlir::success();
1269   }
1270 };
1271 
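/// Convert an array_modify op. Lowering has already emitted the assignment
/// through the element address this op produces, so the pattern only has to
/// materialize that address (through a copy-in/copy-out temporary when a
/// conflict was detected) and forward it together with the array_load result.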
1272 class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> {
1273 public:
1274   explicit ArrayModifyConversion(mlir::MLIRContext *ctx,
1275                                  const ArrayCopyAnalysisBase &a,
1276                                  const OperationUseMapT &m)
1277       : ArrayUpdateConversionBase{ctx, a, m} {}
1278 
1279   llvm::LogicalResult
1280   matchAndRewrite(ArrayModifyOp modify,
1281                   mlir::PatternRewriter &rewriter) const override {
1282     auto loc = modify.getLoc();
1283     auto assignElement = [](mlir::Value) {
1284       // Assignment already materialized by lowering using lhs element address.
1285     };
1286     auto lhsEltRefType = modify.getResult(0).getType();
1287     auto [lhsEltCoor, lhsLoadResult] = materializeAssignment(
1288         loc, rewriter, modify, assignElement, lhsEltRefType);
1289     modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult});
1290     rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult});
1291     return mlir::success();
1292   }
1293 };
1294 
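/// Convert an array_fetch op to an element address computation followed by a
/// load, or to just the address when the fetched element type is already a
/// reference. A minimal sketch of the non-reference case (operands and types
/// elided with `...`; names are illustrative only):
/// ```mlir
///   %v = fir.array_fetch %a, %i, %j : ...
///   // becomes
///   %addr = fir.array_coor ...
///   %v = fir.load %addr : ...
/// ```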
1295 class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> {
1296 public:
1297   explicit ArrayFetchConversion(mlir::MLIRContext *ctx,
1298                                 const OperationUseMapT &m)
1299       : OpRewritePattern{ctx}, useMap{m} {}
1300 
1301   llvm::LogicalResult
1302   matchAndRewrite(ArrayFetchOp fetch,
1303                   mlir::PatternRewriter &rewriter) const override {
1304     auto *op = fetch.getOperation();
1305     rewriter.setInsertionPoint(op);
1306     auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
1307     auto loc = fetch.getLoc();
1308     auto coor = genCoorOp(
1309         rewriter, loc, getEleTy(load.getType()), toRefType(fetch.getType()),
1310         load.getMemref(), load.getShape(), load.getSlice(), fetch.getIndices(),
1311         load, fetch->hasAttr(factory::attrFortranArrayOffsets()));
1312     if (isa_ref_type(fetch.getType()))
1313       rewriter.replaceOp(fetch, coor);
1314     else
1315       rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor);
1316     return mlir::success();
1317   }
1318 
1319 private:
1320   const OperationUseMapT &useMap;
1321 };
1322 
1323 /// An array_access op is like an array_fetch op, except that it does not imply
1324 /// a load op; it operates in the reference domain.
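/// With no conflict, the access rewrites directly to the element address in
/// the original storage; when the access is tied to an array_amend involved in
/// a potential conflict, the address is taken into a cloned temporary instead
/// (see referenceToClone). Sketch of the simple case (operands and types
/// elided with `...`; names are illustrative only):
/// ```mlir
///   %addr = fir.array_access %a, %i, %j : ...
///   // becomes
///   %addr = fir.array_coor ...
/// ```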
1325 class ArrayAccessConversion : public ArrayUpdateConversionBase<ArrayAccessOp> {
1326 public:
1327   explicit ArrayAccessConversion(mlir::MLIRContext *ctx,
1328                                  const ArrayCopyAnalysisBase &a,
1329                                  const OperationUseMapT &m)
1330       : ArrayUpdateConversionBase{ctx, a, m} {}
1331 
1332   llvm::LogicalResult
1333   matchAndRewrite(ArrayAccessOp access,
1334                   mlir::PatternRewriter &rewriter) const override {
1335     auto *op = access.getOperation();
1336     auto loc = access.getLoc();
1337     if (analysis.inAmendAccessSet(op)) {
1338       // This array_access is associated with an array_amend and there is a
1339       // conflict. Make a copy to store into.
1340       auto result = referenceToClone(loc, rewriter, access);
1341       access.replaceAllUsesWith(result);
1342       rewriter.replaceOp(access, result);
1343       return mlir::success();
1344     }
1345     rewriter.setInsertionPoint(op);
1346     auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
1347     auto coor = genCoorOp(
1348         rewriter, loc, getEleTy(load.getType()), toRefType(access.getType()),
1349         load.getMemref(), load.getShape(), load.getSlice(), access.getIndices(),
1350         load, access->hasAttr(factory::attrFortranArrayOffsets()));
1351     rewriter.replaceOp(access, coor);
1352     return mlir::success();
1353   }
1354 };
1355 
1356 /// An array_amend op is a marker to record which array access is being used to
1357 /// update an array value. After this pass runs, an array_amend has no
1358 /// semantics. We rewrite these to undefined values here to remove them while
1359 /// preserving SSA form.
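/// Sketch (operands and types elided with `...`; names are illustrative only):
/// ```mlir
///   %amended = fir.array_amend %av, %addr : ...
///   // becomes
///   %amended = fir.undefined ...
/// ```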
1360 class ArrayAmendConversion : public mlir::OpRewritePattern<ArrayAmendOp> {
1361 public:
1362   explicit ArrayAmendConversion(mlir::MLIRContext *ctx)
1363       : OpRewritePattern{ctx} {}
1364 
1365   llvm::LogicalResult
1366   matchAndRewrite(ArrayAmendOp amend,
1367                   mlir::PatternRewriter &rewriter) const override {
1368     auto *op = amend.getOperation();
1369     rewriter.setInsertionPoint(op);
1370     auto loc = amend.getLoc();
1371     auto undef = rewriter.create<UndefOp>(loc, amend.getType());
1372     rewriter.replaceOp(amend, undef.getResult());
1373     return mlir::success();
1374   }
1375 };
1376 
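/// Driver for the array-value-copy rewrite. Phase 1 rewrites the element-wise
/// ops (array_fetch, array_update, array_modify, array_access, array_amend)
/// using the results of the conflict analysis; phase 2 then converts the
/// remaining array_load and array_merge_store ops.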
1377 class ArrayValueCopyConverter
1378     : public fir::impl::ArrayValueCopyBase<ArrayValueCopyConverter> {
1379 public:
1380   ArrayValueCopyConverter() = default;
1381   ArrayValueCopyConverter(const fir::ArrayValueCopyOptions &options)
1382       : Base(options) {}
1383 
1384   void runOnOperation() override {
1385     auto func = getOperation();
1386     LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '"
1387                             << func.getName() << "'\n");
1388     auto *context = &getContext();
1389 
1390     // Perform the conflict analysis.
1391     const ArrayCopyAnalysisBase *analysis;
1392     if (optimizeConflicts)
1393       analysis = &getAnalysis<ArrayCopyAnalysisOptimized>();
1394     else
1395       analysis = &getAnalysis<ArrayCopyAnalysis>();
1396 
1397     const auto &useMap = analysis->getUseMap();
1398 
1399     mlir::RewritePatternSet patterns1(context);
1400     patterns1.insert<ArrayFetchConversion>(context, useMap);
1401     patterns1.insert<ArrayUpdateConversion>(context, *analysis, useMap);
1402     patterns1.insert<ArrayModifyConversion>(context, *analysis, useMap);
1403     patterns1.insert<ArrayAccessConversion>(context, *analysis, useMap);
1404     patterns1.insert<ArrayAmendConversion>(context);
1405     mlir::ConversionTarget target(*context);
1406     target
1407         .addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
1408                          mlir::arith::ArithDialect, mlir::func::FuncDialect>();
1409     target.addIllegalOp<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp,
1410                         ArrayUpdateOp, ArrayModifyOp>();
1411     // Rewrite the array fetch and array update ops.
1412     if (mlir::failed(
1413             mlir::applyPartialConversion(func, target, std::move(patterns1)))) {
1414       mlir::emitError(mlir::UnknownLoc::get(context),
1415                       "failure in array-value-copy pass, phase 1");
1416       signalPassFailure();
1417     }
1418 
1419     mlir::RewritePatternSet patterns2(context);
1420     patterns2.insert<ArrayLoadConversion>(context);
1421     patterns2.insert<ArrayMergeStoreConversion>(context);
1422     target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>();
1423     if (mlir::failed(
1424             mlir::applyPartialConversion(func, target, std::move(patterns2)))) {
1425       mlir::emitError(mlir::UnknownLoc::get(context),
1426                       "failure in array-value-copy pass, phase 2");
1427       signalPassFailure();
1428     }
1429   }
1430 };
1431 } // namespace
1432 
1433 std::unique_ptr<mlir::Pass>
1434 fir::createArrayValueCopyPass(fir::ArrayValueCopyOptions options) {
1435   return std::make_unique<ArrayValueCopyConverter>(options);
1436 }
1437