xref: /llvm-project/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp (revision 092601d4baab7c13c06b31eda2d5bed91d9a6b65)
1 //===-- ArrayValueCopy.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Builder/BoxValue.h"
11 #include "flang/Optimizer/Builder/FIRBuilder.h"
12 #include "flang/Optimizer/Builder/Factory.h"
13 #include "flang/Optimizer/Dialect/FIRDialect.h"
14 #include "flang/Optimizer/Support/FIRContext.h"
15 #include "flang/Optimizer/Transforms/Passes.h"
16 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "flang-array-value-copy"
22 
23 using namespace fir;
24 using namespace mlir;
25 
26 using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>;
27 
28 namespace {
29 
30 /// Array copy analysis.
31 /// Perform an interference analysis between array values.
32 ///
33 /// Lowering will generate a sequence of the following form.
34 /// ```mlir
35 ///   %a_1 = fir.array_load %array_1(%shape) : ...
36 ///   ...
37 ///   %a_j = fir.array_load %array_j(%shape) : ...
38 ///   ...
39 ///   %a_n = fir.array_load %array_n(%shape) : ...
40 ///     ...
41 ///     %v_i = fir.array_fetch %a_i, ...
42 ///     %a_j1 = fir.array_update %a_j, ...
43 ///     ...
44 ///   fir.array_merge_store %a_j, %a_jn to %array_j : ...
45 /// ```
46 ///
/// The analysis determines whether there are any conflicts. A conflict exists
/// when one of the following cases occurs.
49 ///
50 /// 1. There is an `array_update` to an array value, a_j, such that a_j was
51 /// loaded from the same array memory reference (array_j) but with a different
52 /// shape as the other array values a_i, where i != j. [Possible overlapping
53 /// arrays.]
54 ///
55 /// 2. There is either an array_fetch or array_update of a_j with a different
56 /// set of index values. [Possible loop-carried dependence.]
57 ///
58 /// If none of the array values overlap in storage and the accesses are not
59 /// loop-carried, then the arrays are conflict-free and no copies are required.
60 class ArrayCopyAnalysis {
61 public:
62   using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>;
63   using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>;
64   using LoadMapSetsT =
65       llvm::DenseMap<mlir::Operation *, SmallVector<Operation *>>;
66 
67   ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); }
68 
69   mlir::Operation *getOperation() const { return operation; }
70 
71   /// Return true iff the `array_merge_store` has potential conflicts.
72   bool hasPotentialConflict(mlir::Operation *op) const {
73     LLVM_DEBUG(llvm::dbgs()
74                << "looking for a conflict on " << *op
75                << " and the set has a total of " << conflicts.size() << '\n');
76     return conflicts.contains(op);
77   }
78 
79   /// Return the use map. The use map maps array fetch and update operations
80   /// back to the array load that is the original source of the array value.
81   const OperationUseMapT &getUseMap() const { return useMap; }
82 
83   /// Find all the array operations that access the array value that is loaded
84   /// by the array load operation, `load`.
85   const llvm::SmallVector<mlir::Operation *> &arrayAccesses(ArrayLoadOp load);
86 
87 private:
88   void construct(mlir::Operation *topLevelOp);
89 
90   mlir::Operation *operation; // operation that analysis ran upon
91   ConflictSetT conflicts;     // set of conflicts (loads and merge stores)
92   OperationUseMapT useMap;
93   LoadMapSetsT loadMapSets;
94 };
95 } // namespace
96 
97 namespace {
98 /// Helper class to collect all array operations that produced an array value.
99 class ReachCollector {
100 private:
101   // If provided, the `loopRegion` is the body of a loop that produces the array
102   // of interest.
103   ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach,
104                  mlir::Region *loopRegion)
105       : reach{reach}, loopRegion{loopRegion} {}
106 
107   void collectArrayAccessFrom(mlir::Operation *op, mlir::ValueRange range) {
108     llvm::errs() << "COLLECT " << *op << "\n";
109     if (range.empty()) {
110       collectArrayAccessFrom(op, mlir::Value{});
111       return;
112     }
113     for (mlir::Value v : range)
114       collectArrayAccessFrom(v);
115   }
116 
117   // TODO: Replace recursive algorithm on def-use chain with an iterative one
118   // with an explicit stack.
119   void collectArrayAccessFrom(mlir::Operation *op, mlir::Value val) {
120     // `val` is defined by an Op, process the defining Op.
121     // If `val` is defined by a region containing Op, we want to drill down
122     // and through that Op's region(s).
123     llvm::errs() << "COLLECT " << *op << "\n";
124     LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n');
125     auto popFn = [&](auto rop) {
126       assert(val && "op must have a result value");
127       auto resNum = val.cast<mlir::OpResult>().getResultNumber();
128       llvm::SmallVector<mlir::Value> results;
129       rop.resultToSourceOps(results, resNum);
130       for (auto u : results)
131         collectArrayAccessFrom(u);
132     };
133     if (auto rop = mlir::dyn_cast<fir::DoLoopOp>(op)) {
134       popFn(rop);
135       return;
136     }
137     if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) {
138       popFn(rop);
139       return;
140     }
141     if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
142       if (opIsInsideLoops(mergeStore))
143         collectArrayAccessFrom(mergeStore.getSequence());
144       return;
145     }
146 
147     if (mlir::isa<AllocaOp, AllocMemOp>(op)) {
148       // Look for any stores inside the loops, and collect an array operation
149       // that produced the value being stored to it.
150       for (mlir::Operation *user : op->getUsers())
151         if (auto store = mlir::dyn_cast<fir::StoreOp>(user))
152           if (opIsInsideLoops(store))
153             collectArrayAccessFrom(store.getValue());
154       return;
155     }
156 
157     // Otherwise, Op does not contain a region so just chase its operands.
158     if (mlir::isa<ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp, ArrayFetchOp>(
159             op)) {
160       LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
161       reach.emplace_back(op);
162     }
163     // Array modify assignment is performed on the result. So the analysis
164     // must look at the what is done with the result.
165     if (mlir::isa<ArrayModifyOp>(op))
166       for (mlir::Operation *user : op->getResult(0).getUsers())
167         followUsers(user);
168 
169     for (auto u : op->getOperands())
170       collectArrayAccessFrom(u);
171   }
172 
173   void collectArrayAccessFrom(mlir::BlockArgument ba) {
174     auto *parent = ba.getOwner()->getParentOp();
175     // If inside an Op holding a region, the block argument corresponds to an
176     // argument passed to the containing Op.
177     auto popFn = [&](auto rop) {
178       collectArrayAccessFrom(rop.blockArgToSourceOp(ba.getArgNumber()));
179     };
180     if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) {
181       popFn(rop);
182       return;
183     }
184     if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) {
185       popFn(rop);
186       return;
187     }
188     // Otherwise, a block argument is provided via the pred blocks.
189     for (auto *pred : ba.getOwner()->getPredecessors()) {
190       auto u = pred->getTerminator()->getOperand(ba.getArgNumber());
191       collectArrayAccessFrom(u);
192     }
193   }
194 
195   // Recursively trace operands to find all array operations relating to the
196   // values merged.
197   void collectArrayAccessFrom(mlir::Value val) {
198     if (!val || visited.contains(val))
199       return;
200     visited.insert(val);
201 
202     // Process a block argument.
203     if (auto ba = val.dyn_cast<mlir::BlockArgument>()) {
204       collectArrayAccessFrom(ba);
205       return;
206     }
207 
208     // Process an Op.
209     if (auto *op = val.getDefiningOp()) {
210       collectArrayAccessFrom(op, val);
211       return;
212     }
213 
214     fir::emitFatalError(val.getLoc(), "unhandled value");
215   }
216 
217   /// Is \op inside the loop nest region ?
218   bool opIsInsideLoops(mlir::Operation *op) const {
219     return loopRegion && loopRegion->isAncestor(op->getParentRegion());
220   }
221 
222   /// Recursively trace the use of an operation results, calling
223   /// collectArrayAccessFrom on the direct and indirect user operands.
224   /// TODO: Replace recursive algorithm on def-use chain with an iterative one
225   /// with an explicit stack.
226   void followUsers(mlir::Operation *op) {
227     for (auto userOperand : op->getOperands())
228       collectArrayAccessFrom(userOperand);
229     // Go through potential converts/coordinate_op.
230     for (mlir::Operation *indirectUser : op->getUsers())
231       followUsers(indirectUser);
232   }
233 
234   llvm::SmallVectorImpl<mlir::Operation *> &reach;
235   llvm::SmallPtrSet<mlir::Value, 16> visited;
236   /// Region of the loops nest that produced the array value.
237   mlir::Region *loopRegion;
238 
239 public:
240   /// Return all ops that produce the array value that is stored into the
241   /// `array_merge_store`.
242   static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach,
243                              mlir::Value seq) {
244     reach.clear();
245     mlir::Region *loopRegion = nullptr;
246     // Only `DoLoopOp` is tested here since array operations are currently only
247     // associated with this kind of loop.
248     if (auto doLoop =
249             mlir::dyn_cast_or_null<fir::DoLoopOp>(seq.getDefiningOp()))
250       loopRegion = &doLoop->getRegion(0);
251     ReachCollector collector(reach, loopRegion);
252     collector.collectArrayAccessFrom(seq);
253   }
254 };
255 } // namespace
256 
257 /// Find all the array operations that access the array value that is loaded by
258 /// the array load operation, `load`.
259 const llvm::SmallVector<mlir::Operation *> &
260 ArrayCopyAnalysis::arrayAccesses(ArrayLoadOp load) {
261   auto lmIter = loadMapSets.find(load);
262   if (lmIter != loadMapSets.end())
263     return lmIter->getSecond();
264 
265   llvm::SmallVector<mlir::Operation *> accesses;
266   UseSetT visited;
267   llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig]
268 
269   auto appendToQueue = [&](mlir::Value val) {
270     for (mlir::OpOperand &use : val.getUses())
271       if (!visited.count(&use)) {
272         visited.insert(&use);
273         queue.push_back(&use);
274       }
275   };
276 
277   // Build the set of uses of `original`.
278   // let USES = { uses of original fir.load }
279   appendToQueue(load);
280 
281   // Process the worklist until done.
282   while (!queue.empty()) {
283     mlir::OpOperand *operand = queue.pop_back_val();
284     mlir::Operation *owner = operand->getOwner();
285 
286     auto structuredLoop = [&](auto ro) {
287       if (auto blockArg = ro.iterArgToBlockArg(operand->get())) {
288         int64_t arg = blockArg.getArgNumber();
289         mlir::Value output = ro.getResult(ro.getFinalValue() ? arg : arg - 1);
290         appendToQueue(output);
291         appendToQueue(blockArg);
292       }
293     };
294     // TODO: this need to be updated to use the control-flow interface.
295     auto branchOp = [&](mlir::Block *dest, OperandRange operands) {
296       if (operands.empty())
297         return;
298 
299       // Check if this operand is within the range.
300       unsigned operandIndex = operand->getOperandNumber();
301       unsigned operandsStart = operands.getBeginOperandIndex();
302       if (operandIndex < operandsStart ||
303           operandIndex >= (operandsStart + operands.size()))
304         return;
305 
306       // Index the successor.
307       unsigned argIndex = operandIndex - operandsStart;
308       appendToQueue(dest->getArgument(argIndex));
309     };
310     // Thread uses into structured loop bodies and return value uses.
311     if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) {
312       structuredLoop(ro);
313     } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) {
314       structuredLoop(ro);
315     } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) {
316       // Thread any uses of fir.if that return the marked array value.
317       if (auto ifOp = rs->getParentOfType<fir::IfOp>())
318         appendToQueue(ifOp.getResult(operand->getOperandNumber()));
319     } else if (mlir::isa<ArrayFetchOp>(owner)) {
320       // Keep track of array value fetches.
321       LLVM_DEBUG(llvm::dbgs()
322                  << "add fetch {" << *owner << "} to array value set\n");
323       accesses.push_back(owner);
324     } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) {
325       // Keep track of array value updates and thread the return value uses.
326       LLVM_DEBUG(llvm::dbgs()
327                  << "add update {" << *owner << "} to array value set\n");
328       accesses.push_back(owner);
329       appendToQueue(update.getResult());
330     } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) {
331       // Keep track of array value modification and thread the return value
332       // uses.
333       LLVM_DEBUG(llvm::dbgs()
334                  << "add modify {" << *owner << "} to array value set\n");
335       accesses.push_back(owner);
336       appendToQueue(update.getResult(1));
337     } else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) {
338       branchOp(br.getDest(), br.getDestOperands());
339     } else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) {
340       branchOp(br.getTrueDest(), br.getTrueOperands());
341       branchOp(br.getFalseDest(), br.getFalseOperands());
342     } else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
343       // do nothing
344     } else {
345       llvm::report_fatal_error("array value reached unexpected op");
346     }
347   }
348   return loadMapSets.insert({load, accesses}).first->getSecond();
349 }
350 
351 /// Is there a conflict between the array value that was updated and to be
352 /// stored to `st` and the set of arrays loaded (`reach`) and used to compute
353 /// the updated value?
354 static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach,
355                            ArrayMergeStoreOp st) {
356   mlir::Value load;
357   mlir::Value addr = st.getMemref();
358   auto stEleTy = fir::dyn_cast_ptrOrBoxEleTy(addr.getType());
359   for (auto *op : reach) {
360     auto ld = mlir::dyn_cast<ArrayLoadOp>(op);
361     if (!ld)
362       continue;
363     mlir::Type ldTy = ld.getMemref().getType();
364     if (auto boxTy = ldTy.dyn_cast<fir::BoxType>())
365       ldTy = boxTy.getEleTy();
366     if (ldTy.isa<fir::PointerType>() && stEleTy == dyn_cast_ptrEleTy(ldTy))
367       return true;
368     if (ld.getMemref() == addr) {
369       if (ld.getResult() != st.getOriginal())
370         return true;
371       if (load)
372         return true;
373       load = ld;
374     }
375   }
376   return false;
377 }
378 
379 /// Check if there is any potential conflict in the chained update operations
380 /// (ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp) while merging back to the
381 /// array. A potential conflict is detected if two operations work on the same
382 /// indices.
383 static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> accesses) {
384   if (accesses.size() < 2)
385     return false;
386   llvm::SmallVector<mlir::Value> indices;
387   LLVM_DEBUG(llvm::dbgs() << "check merge conflict on with " << accesses.size()
388                           << " accesses on the list\n");
389   for (auto *op : accesses) {
390     assert((mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) &&
391            "unexpected operation in analysis");
392     llvm::SmallVector<mlir::Value> compareVector;
393     if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) {
394       if (indices.empty()) {
395         indices = u.getIndices();
396         continue;
397       }
398       compareVector = u.getIndices();
399     } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) {
400       if (indices.empty()) {
401         indices = f.getIndices();
402         continue;
403       }
404       compareVector = f.getIndices();
405     } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) {
406       if (indices.empty()) {
407         indices = f.getIndices();
408         continue;
409       }
410       compareVector = f.getIndices();
411     }
412     if (compareVector != indices)
413       return true;
414     LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n");
415   }
416   return false;
417 }
418 
419 // Are either of types of conflicts present?
420 inline bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
421                              llvm::ArrayRef<mlir::Operation *> accesses,
422                              ArrayMergeStoreOp st) {
423   return conflictOnLoad(reach, st) || conflictOnMerge(accesses);
424 }
425 
426 /// Constructor of the array copy analysis.
427 /// This performs the analysis and saves the intermediate results.
428 void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) {
429   topLevelOp->walk([&](Operation *op) {
430     if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) {
431       llvm::SmallVector<Operation *> values;
432       ReachCollector::reachingValues(values, st.getSequence());
433       const llvm::SmallVector<Operation *> &accesses = arrayAccesses(
434           mlir::cast<ArrayLoadOp>(st.getOriginal().getDefiningOp()));
435       if (conflictDetected(values, accesses, st)) {
436         LLVM_DEBUG(llvm::dbgs()
437                    << "CONFLICT: copies required for " << st << '\n'
438                    << "   adding conflicts on: " << op << " and "
439                    << st.getOriginal() << '\n');
440         conflicts.insert(op);
441         conflicts.insert(st.getOriginal().getDefiningOp());
442       }
443       auto *ld = st.getOriginal().getDefiningOp();
444       LLVM_DEBUG(llvm::dbgs()
445                  << "map: adding {" << *ld << " -> " << st << "}\n");
446       useMap.insert({ld, op});
447     } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) {
448       const llvm::SmallVector<mlir::Operation *> &accesses =
449           arrayAccesses(load);
450       LLVM_DEBUG(llvm::dbgs() << "process load: " << load
451                               << ", accesses: " << accesses.size() << '\n');
452       for (auto *acc : accesses) {
453         LLVM_DEBUG(llvm::dbgs() << " access: " << *acc << '\n');
454         assert((mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(acc)));
455         if (!useMap.insert({acc, op}).second) {
456           mlir::emitError(
457               load.getLoc(),
458               "The parallel semantics of multiple array_merge_stores per "
459               "array_load are not supported.");
460           return;
461         }
462         LLVM_DEBUG(llvm::dbgs()
463                    << "map: adding {" << *acc << "} -> {" << load << "}\n");
464       }
465     }
466   });
467 }
468 
469 namespace {
470 class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> {
471 public:
472   using OpRewritePattern::OpRewritePattern;
473 
474   mlir::LogicalResult
475   matchAndRewrite(ArrayLoadOp load,
476                   mlir::PatternRewriter &rewriter) const override {
477     LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n");
478     rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType());
479     return mlir::success();
480   }
481 };
482 
483 class ArrayMergeStoreConversion
484     : public mlir::OpRewritePattern<ArrayMergeStoreOp> {
485 public:
486   using OpRewritePattern::OpRewritePattern;
487 
488   mlir::LogicalResult
489   matchAndRewrite(ArrayMergeStoreOp store,
490                   mlir::PatternRewriter &rewriter) const override {
491     LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n");
492     rewriter.eraseOp(store);
493     return mlir::success();
494   }
495 };
496 } // namespace
497 
498 static mlir::Type getEleTy(mlir::Type ty) {
499   if (auto t = dyn_cast_ptrEleTy(ty))
500     ty = t;
501   if (auto t = ty.dyn_cast<SequenceType>())
502     ty = t.getEleTy();
503   // FIXME: keep ptr/heap/ref information.
504   return ReferenceType::get(ty);
505 }
506 
507 // Extract extents from the ShapeOp/ShapeShiftOp into the result vector.
508 // TODO: getExtents on op should return a ValueRange instead of a vector.
509 static void getExtents(llvm::SmallVectorImpl<mlir::Value> &result,
510                        mlir::Value shape) {
511   auto *shapeOp = shape.getDefiningOp();
512   if (auto s = mlir::dyn_cast<fir::ShapeOp>(shapeOp)) {
513     auto e = s.getExtents();
514     result.insert(result.end(), e.begin(), e.end());
515     return;
516   }
517   if (auto s = mlir::dyn_cast<fir::ShapeShiftOp>(shapeOp)) {
518     auto e = s.getExtents();
519     result.insert(result.end(), e.begin(), e.end());
520     return;
521   }
522   llvm::report_fatal_error("not a fir.shape/fir.shape_shift op");
523 }
524 
525 // Place the extents of the array loaded by an ArrayLoadOp into the result
526 // vector and return a ShapeOp/ShapeShiftOp with the corresponding extents. If
527 // the ArrayLoadOp is loading a fir.box, code will be generated to read the
528 // extents from the fir.box, and a the retunred ShapeOp is built with the read
529 // extents.
530 // Otherwise, the extents will be extracted from the ShapeOp/ShapeShiftOp
531 // argument of the ArrayLoadOp that is returned.
532 static mlir::Value
533 getOrReadExtentsAndShapeOp(mlir::Location loc, mlir::PatternRewriter &rewriter,
534                            fir::ArrayLoadOp loadOp,
535                            llvm::SmallVectorImpl<mlir::Value> &result) {
536   assert(result.empty());
537   if (auto boxTy = loadOp.getMemref().getType().dyn_cast<fir::BoxType>()) {
538     auto rank = fir::dyn_cast_ptrOrBoxEleTy(boxTy)
539                     .cast<fir::SequenceType>()
540                     .getDimension();
541     auto idxTy = rewriter.getIndexType();
542     for (decltype(rank) dim = 0; dim < rank; ++dim) {
543       auto dimVal = rewriter.create<arith::ConstantIndexOp>(loc, dim);
544       auto dimInfo = rewriter.create<fir::BoxDimsOp>(
545           loc, idxTy, idxTy, idxTy, loadOp.getMemref(), dimVal);
546       result.emplace_back(dimInfo.getResult(1));
547     }
548     auto shapeType = fir::ShapeType::get(rewriter.getContext(), rank);
549     return rewriter.create<fir::ShapeOp>(loc, shapeType, result);
550   }
551   getExtents(result, loadOp.getShape());
552   return loadOp.getShape();
553 }
554 
555 static mlir::Type toRefType(mlir::Type ty) {
556   if (fir::isa_ref_type(ty))
557     return ty;
558   return fir::ReferenceType::get(ty);
559 }
560 
561 static mlir::Value
562 genCoorOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type eleTy,
563           mlir::Type resTy, mlir::Value alloc, mlir::Value shape,
564           mlir::Value slice, mlir::ValueRange indices,
565           mlir::ValueRange typeparams, bool skipOrig = false) {
566   llvm::SmallVector<mlir::Value> originated;
567   if (skipOrig)
568     originated.assign(indices.begin(), indices.end());
569   else
570     originated = fir::factory::originateIndices(loc, rewriter, alloc.getType(),
571                                                 shape, indices);
572   auto seqTy = fir::dyn_cast_ptrOrBoxEleTy(alloc.getType());
573   assert(seqTy && seqTy.isa<fir::SequenceType>());
574   const auto dimension = seqTy.cast<fir::SequenceType>().getDimension();
575   mlir::Value result = rewriter.create<fir::ArrayCoorOp>(
576       loc, eleTy, alloc, shape, slice,
577       llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension),
578       typeparams);
579   if (dimension < originated.size())
580     result = rewriter.create<fir::CoordinateOp>(
581         loc, resTy, result,
582         llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension));
583   return result;
584 }
585 
586 namespace {
587 /// Conversion of fir.array_update and fir.array_modify Ops.
588 /// If there is a conflict for the update, then we need to perform a
589 /// copy-in/copy-out to preserve the original values of the array. If there is
590 /// no conflict, then it is save to eschew making any copies.
591 template <typename ArrayOp>
592 class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> {
593 public:
594   explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx,
595                                      const ArrayCopyAnalysis &a,
596                                      const OperationUseMapT &m)
597       : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {}
598 
599   void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter,
600                     mlir::Value dst, mlir::Value src, mlir::Value shapeOp,
601                     mlir::Type arrTy) const {
602     auto insPt = rewriter.saveInsertionPoint();
603     llvm::SmallVector<mlir::Value> indices;
604     llvm::SmallVector<mlir::Value> extents;
605     getExtents(extents, shapeOp);
606     // Build loop nest from column to row.
607     for (auto sh : llvm::reverse(extents)) {
608       auto idxTy = rewriter.getIndexType();
609       auto ubi = rewriter.create<fir::ConvertOp>(loc, idxTy, sh);
610       auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
611       auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
612       auto ub = rewriter.create<arith::SubIOp>(loc, idxTy, ubi, one);
613       auto loop = rewriter.create<fir::DoLoopOp>(loc, zero, ub, one);
614       rewriter.setInsertionPointToStart(loop.getBody());
615       indices.push_back(loop.getInductionVar());
616     }
617     // Reverse the indices so they are in column-major order.
618     std::reverse(indices.begin(), indices.end());
619     auto ty = getEleTy(arrTy);
620     auto fromAddr = rewriter.create<fir::ArrayCoorOp>(
621         loc, ty, src, shapeOp, mlir::Value{},
622         fir::factory::originateIndices(loc, rewriter, src.getType(), shapeOp,
623                                        indices),
624         mlir::ValueRange{});
625     auto load = rewriter.create<fir::LoadOp>(loc, fromAddr);
626     auto toAddr = rewriter.create<fir::ArrayCoorOp>(
627         loc, ty, dst, shapeOp, mlir::Value{},
628         fir::factory::originateIndices(loc, rewriter, dst.getType(), shapeOp,
629                                        indices),
630         mlir::ValueRange{});
631     rewriter.create<fir::StoreOp>(loc, load, toAddr);
632     rewriter.restoreInsertionPoint(insPt);
633   }
634 
635   /// Copy the RHS element into the LHS and insert copy-in/copy-out between a
636   /// temp and the LHS if the analysis found potential overlaps between the RHS
637   /// and LHS arrays. The element copy generator must be provided through \p
638   /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp.
639   /// Returns the address of the LHS element inside the loop and the LHS
640   /// ArrayLoad result.
641   std::pair<mlir::Value, mlir::Value>
642   materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter,
643                         ArrayOp update,
644                         llvm::function_ref<void(mlir::Value)> assignElement,
645                         mlir::Type lhsEltRefType) const {
646     auto *op = update.getOperation();
647     mlir::Operation *loadOp = useMap.lookup(op);
648     auto load = mlir::cast<ArrayLoadOp>(loadOp);
649     LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n");
650     if (analysis.hasPotentialConflict(loadOp)) {
651       // If there is a conflict between the arrays, then we copy the lhs array
652       // to a temporary, update the temporary, and copy the temporary back to
653       // the lhs array. This yields Fortran's copy-in copy-out array semantics.
654       LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n");
655       rewriter.setInsertionPoint(loadOp);
656       // Copy in.
657       llvm::SmallVector<mlir::Value> extents;
658       mlir::Value shapeOp =
659           getOrReadExtentsAndShapeOp(loc, rewriter, load, extents);
660       auto allocmem = rewriter.create<AllocMemOp>(
661           loc, dyn_cast_ptrOrBoxEleTy(load.getMemref().getType()),
662           load.getTypeparams(), extents);
663       genArrayCopy(load.getLoc(), rewriter, allocmem, load.getMemref(), shapeOp,
664                    load.getType());
665       rewriter.setInsertionPoint(op);
666       mlir::Value coor = genCoorOp(
667           rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem,
668           shapeOp, load.getSlice(), update.getIndices(), load.getTypeparams(),
669           update->hasAttr(fir::factory::attrFortranArrayOffsets()));
670       assignElement(coor);
671       mlir::Operation *storeOp = useMap.lookup(loadOp);
672       auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
673       rewriter.setInsertionPoint(storeOp);
674       // Copy out.
675       genArrayCopy(store.getLoc(), rewriter, store.getMemref(), allocmem,
676                    shapeOp, load.getType());
677       rewriter.create<FreeMemOp>(loc, allocmem);
678       return {coor, load.getResult()};
679     }
680     // Otherwise, when there is no conflict (a possible loop-carried
681     // dependence), the lhs array can be updated in place.
682     LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n");
683     rewriter.setInsertionPoint(op);
684     auto coorTy = getEleTy(load.getType());
685     mlir::Value coor = genCoorOp(
686         rewriter, loc, coorTy, lhsEltRefType, load.getMemref(), load.getShape(),
687         load.getSlice(), update.getIndices(), load.getTypeparams(),
688         update->hasAttr(fir::factory::attrFortranArrayOffsets()));
689     assignElement(coor);
690     return {coor, load.getResult()};
691   }
692 
693 private:
694   const ArrayCopyAnalysis &analysis;
695   const OperationUseMapT &useMap;
696 };
697 
698 class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> {
699 public:
700   explicit ArrayUpdateConversion(mlir::MLIRContext *ctx,
701                                  const ArrayCopyAnalysis &a,
702                                  const OperationUseMapT &m)
703       : ArrayUpdateConversionBase{ctx, a, m} {}
704 
705   mlir::LogicalResult
706   matchAndRewrite(ArrayUpdateOp update,
707                   mlir::PatternRewriter &rewriter) const override {
708     auto loc = update.getLoc();
709     auto assignElement = [&](mlir::Value coor) {
710       rewriter.create<fir::StoreOp>(loc, update.getMerge(), coor);
711     };
712     auto lhsEltRefType = toRefType(update.getMerge().getType());
713     auto [_, lhsLoadResult] = materializeAssignment(
714         loc, rewriter, update, assignElement, lhsEltRefType);
715     update.replaceAllUsesWith(lhsLoadResult);
716     rewriter.replaceOp(update, lhsLoadResult);
717     return mlir::success();
718   }
719 };
720 
721 class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> {
722 public:
723   explicit ArrayModifyConversion(mlir::MLIRContext *ctx,
724                                  const ArrayCopyAnalysis &a,
725                                  const OperationUseMapT &m)
726       : ArrayUpdateConversionBase{ctx, a, m} {}
727 
728   mlir::LogicalResult
729   matchAndRewrite(ArrayModifyOp modify,
730                   mlir::PatternRewriter &rewriter) const override {
731     auto loc = modify.getLoc();
732     auto assignElement = [](mlir::Value) {
733       // Assignment already materialized by lowering using lhs element address.
734     };
735     auto lhsEltRefType = modify.getResult(0).getType();
736     auto [lhsEltCoor, lhsLoadResult] = materializeAssignment(
737         loc, rewriter, modify, assignElement, lhsEltRefType);
738     modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult});
739     rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult});
740     return mlir::success();
741   }
742 };
743 
744 class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> {
745 public:
746   explicit ArrayFetchConversion(mlir::MLIRContext *ctx,
747                                 const OperationUseMapT &m)
748       : OpRewritePattern{ctx}, useMap{m} {}
749 
750   mlir::LogicalResult
751   matchAndRewrite(ArrayFetchOp fetch,
752                   mlir::PatternRewriter &rewriter) const override {
753     auto *op = fetch.getOperation();
754     rewriter.setInsertionPoint(op);
755     auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
756     auto loc = fetch.getLoc();
757     mlir::Value coor =
758         genCoorOp(rewriter, loc, getEleTy(load.getType()),
759                   toRefType(fetch.getType()), load.getMemref(), load.getShape(),
760                   load.getSlice(), fetch.getIndices(), load.getTypeparams(),
761                   fetch->hasAttr(fir::factory::attrFortranArrayOffsets()));
762     rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor);
763     return mlir::success();
764   }
765 
766 private:
767   const OperationUseMapT &useMap;
768 };
769 } // namespace
770 
771 namespace {
772 class ArrayValueCopyConverter
773     : public ArrayValueCopyBase<ArrayValueCopyConverter> {
774 public:
775   void runOnOperation() override {
776     auto func = getOperation();
777     LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '"
778                             << func.getName() << "'\n");
779     auto *context = &getContext();
780 
781     // Perform the conflict analysis.
782     auto &analysis = getAnalysis<ArrayCopyAnalysis>();
783     const auto &useMap = analysis.getUseMap();
784 
785     // Phase 1 is performing a rewrite on the array accesses. Once all the
786     // array accesses are rewritten we can go on phase 2.
787     // Phase 2 gets rid of the useless copy-in/copyout operations. The copy-in
788     // /copy-out refers the Fortran copy-in/copy-out semantics on statements.
789     mlir::RewritePatternSet patterns1(context);
790     patterns1.insert<ArrayFetchConversion>(context, useMap);
791     patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
792     patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
793     mlir::ConversionTarget target(*context);
794     target.addLegalDialect<
795         FIROpsDialect, mlir::scf::SCFDialect, mlir::arith::ArithmeticDialect,
796         mlir::cf::ControlFlowDialect, mlir::func::FuncDialect>();
797     target.addIllegalOp<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>();
798     // Rewrite the array fetch and array update ops.
799     if (mlir::failed(
800             mlir::applyPartialConversion(func, target, std::move(patterns1)))) {
801       mlir::emitError(mlir::UnknownLoc::get(context),
802                       "failure in array-value-copy pass, phase 1");
803       signalPassFailure();
804     }
805 
806     mlir::RewritePatternSet patterns2(context);
807     patterns2.insert<ArrayLoadConversion>(context);
808     patterns2.insert<ArrayMergeStoreConversion>(context);
809     target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>();
810     if (mlir::failed(
811             mlir::applyPartialConversion(func, target, std::move(patterns2)))) {
812       mlir::emitError(mlir::UnknownLoc::get(context),
813                       "failure in array-value-copy pass, phase 2");
814       signalPassFailure();
815     }
816   }
817 };
818 } // namespace
819 
820 std::unique_ptr<mlir::Pass> fir::createArrayValueCopyPass() {
821   return std::make_unique<ArrayValueCopyConverter>();
822 }
823