xref: /llvm-project/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp (revision 196c4279c08d1ef72c2e0196ec69f4a0f8a3a87e)
1 //===-- ArrayValueCopy.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Builder/BoxValue.h"
11 #include "flang/Optimizer/Builder/FIRBuilder.h"
12 #include "flang/Optimizer/Builder/Factory.h"
13 #include "flang/Optimizer/Dialect/FIRDialect.h"
14 #include "flang/Optimizer/Support/FIRContext.h"
15 #include "flang/Optimizer/Transforms/Passes.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include "llvm/Support/Debug.h"
19 
20 #define DEBUG_TYPE "flang-array-value-copy"
21 
22 using namespace fir;
23 
24 using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>;
25 
26 namespace {
27 
28 /// Array copy analysis.
29 /// Perform an interference analysis between array values.
30 ///
31 /// Lowering will generate a sequence of the following form.
32 /// ```mlir
33 ///   %a_1 = fir.array_load %array_1(%shape) : ...
34 ///   ...
35 ///   %a_j = fir.array_load %array_j(%shape) : ...
36 ///   ...
37 ///   %a_n = fir.array_load %array_n(%shape) : ...
38 ///     ...
39 ///     %v_i = fir.array_fetch %a_i, ...
40 ///     %a_j1 = fir.array_update %a_j, ...
41 ///     ...
42 ///   fir.array_merge_store %a_j, %a_jn to %array_j : ...
43 /// ```
44 ///
/// The analysis is to determine if there are any conflicts. A conflict is when
/// one of the following cases occurs.
///
/// 1. There is an `array_update` to an array value, a_j, such that a_j was
/// loaded from the same array memory reference (array_j) but with a different
/// shape than the other array values a_i, where i != j. [Possible overlapping
/// arrays.]
52 ///
53 /// 2. There is either an array_fetch or array_update of a_j with a different
54 /// set of index values. [Possible loop-carried dependence.]
55 ///
56 /// If none of the array values overlap in storage and the accesses are not
57 /// loop-carried, then the arrays are conflict-free and no copies are required.
58 class ArrayCopyAnalysis {
59 public:
60   using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>;
61   using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>;
62   using LoadMapSetsT =
63       llvm::DenseMap<mlir::Operation *, SmallVector<Operation *>>;
64 
65   ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); }
66 
67   mlir::Operation *getOperation() const { return operation; }
68 
69   /// Return true iff the `array_merge_store` has potential conflicts.
70   bool hasPotentialConflict(mlir::Operation *op) const {
71     LLVM_DEBUG(llvm::dbgs()
72                << "looking for a conflict on " << *op
73                << " and the set has a total of " << conflicts.size() << '\n');
74     return conflicts.contains(op);
75   }
76 
77   /// Return the use map. The use map maps array fetch and update operations
78   /// back to the array load that is the original source of the array value.
79   const OperationUseMapT &getUseMap() const { return useMap; }
80 
81   /// Find all the array operations that access the array value that is loaded
82   /// by the array load operation, `load`.
83   const llvm::SmallVector<mlir::Operation *> &arrayAccesses(ArrayLoadOp load);
84 
85 private:
86   void construct(mlir::Operation *topLevelOp);
87 
88   mlir::Operation *operation; // operation that analysis ran upon
89   ConflictSetT conflicts;     // set of conflicts (loads and merge stores)
90   OperationUseMapT useMap;
91   LoadMapSetsT loadMapSets;
92 };
93 } // namespace
94 
95 namespace {
96 /// Helper class to collect all array operations that produced an array value.
97 class ReachCollector {
98 private:
99   // If provided, the `loopRegion` is the body of a loop that produces the array
100   // of interest.
101   ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach,
102                  mlir::Region *loopRegion)
103       : reach{reach}, loopRegion{loopRegion} {}
104 
105   void collectArrayAccessFrom(mlir::Operation *op, mlir::ValueRange range) {
106     llvm::errs() << "COLLECT " << *op << "\n";
107     if (range.empty()) {
108       collectArrayAccessFrom(op, mlir::Value{});
109       return;
110     }
111     for (mlir::Value v : range)
112       collectArrayAccessFrom(v);
113   }
114 
115   // TODO: Replace recursive algorithm on def-use chain with an iterative one
116   // with an explicit stack.
117   void collectArrayAccessFrom(mlir::Operation *op, mlir::Value val) {
118     // `val` is defined by an Op, process the defining Op.
119     // If `val` is defined by a region containing Op, we want to drill down
120     // and through that Op's region(s).
121     llvm::errs() << "COLLECT " << *op << "\n";
122     LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n');
123     auto popFn = [&](auto rop) {
124       assert(val && "op must have a result value");
125       auto resNum = val.cast<mlir::OpResult>().getResultNumber();
126       llvm::SmallVector<mlir::Value> results;
127       rop.resultToSourceOps(results, resNum);
128       for (auto u : results)
129         collectArrayAccessFrom(u);
130     };
131     if (auto rop = mlir::dyn_cast<fir::DoLoopOp>(op)) {
132       popFn(rop);
133       return;
134     }
135     if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) {
136       popFn(rop);
137       return;
138     }
139     if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
140       if (opIsInsideLoops(mergeStore))
141         collectArrayAccessFrom(mergeStore.sequence());
142       return;
143     }
144 
145     if (mlir::isa<AllocaOp, AllocMemOp>(op)) {
146       // Look for any stores inside the loops, and collect an array operation
147       // that produced the value being stored to it.
148       for (mlir::Operation *user : op->getUsers())
149         if (auto store = mlir::dyn_cast<fir::StoreOp>(user))
150           if (opIsInsideLoops(store))
151             collectArrayAccessFrom(store.value());
152       return;
153     }
154 
155     // Otherwise, Op does not contain a region so just chase its operands.
156     if (mlir::isa<ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp, ArrayFetchOp>(
157             op)) {
158       LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
159       reach.emplace_back(op);
160     }
161     // Array modify assignment is performed on the result. So the analysis
162     // must look at the what is done with the result.
163     if (mlir::isa<ArrayModifyOp>(op))
164       for (mlir::Operation *user : op->getResult(0).getUsers())
165         followUsers(user);
166 
167     for (auto u : op->getOperands())
168       collectArrayAccessFrom(u);
169   }
170 
171   void collectArrayAccessFrom(mlir::BlockArgument ba) {
172     auto *parent = ba.getOwner()->getParentOp();
173     // If inside an Op holding a region, the block argument corresponds to an
174     // argument passed to the containing Op.
175     auto popFn = [&](auto rop) {
176       collectArrayAccessFrom(rop.blockArgToSourceOp(ba.getArgNumber()));
177     };
178     if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) {
179       popFn(rop);
180       return;
181     }
182     if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) {
183       popFn(rop);
184       return;
185     }
186     // Otherwise, a block argument is provided via the pred blocks.
187     for (auto *pred : ba.getOwner()->getPredecessors()) {
188       auto u = pred->getTerminator()->getOperand(ba.getArgNumber());
189       collectArrayAccessFrom(u);
190     }
191   }
192 
193   // Recursively trace operands to find all array operations relating to the
194   // values merged.
195   void collectArrayAccessFrom(mlir::Value val) {
196     if (!val || visited.contains(val))
197       return;
198     visited.insert(val);
199 
200     // Process a block argument.
201     if (auto ba = val.dyn_cast<mlir::BlockArgument>()) {
202       collectArrayAccessFrom(ba);
203       return;
204     }
205 
206     // Process an Op.
207     if (auto *op = val.getDefiningOp()) {
208       collectArrayAccessFrom(op, val);
209       return;
210     }
211 
212     fir::emitFatalError(val.getLoc(), "unhandled value");
213   }
214 
215   /// Is \op inside the loop nest region ?
216   bool opIsInsideLoops(mlir::Operation *op) const {
217     return loopRegion && loopRegion->isAncestor(op->getParentRegion());
218   }
219 
220   /// Recursively trace the use of an operation results, calling
221   /// collectArrayAccessFrom on the direct and indirect user operands.
222   /// TODO: Replace recursive algorithm on def-use chain with an iterative one
223   /// with an explicit stack.
224   void followUsers(mlir::Operation *op) {
225     for (auto userOperand : op->getOperands())
226       collectArrayAccessFrom(userOperand);
227     // Go through potential converts/coordinate_op.
228     for (mlir::Operation *indirectUser : op->getUsers())
229       followUsers(indirectUser);
230   }
231 
232   llvm::SmallVectorImpl<mlir::Operation *> &reach;
233   llvm::SmallPtrSet<mlir::Value, 16> visited;
234   /// Region of the loops nest that produced the array value.
235   mlir::Region *loopRegion;
236 
237 public:
238   /// Return all ops that produce the array value that is stored into the
239   /// `array_merge_store`.
240   static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach,
241                              mlir::Value seq) {
242     reach.clear();
243     mlir::Region *loopRegion = nullptr;
244     // Only `DoLoopOp` is tested here since array operations are currently only
245     // associated with this kind of loop.
246     if (auto doLoop =
247             mlir::dyn_cast_or_null<fir::DoLoopOp>(seq.getDefiningOp()))
248       loopRegion = &doLoop->getRegion(0);
249     ReachCollector collector(reach, loopRegion);
250     collector.collectArrayAccessFrom(seq);
251   }
252 };
253 } // namespace
254 
255 /// Find all the array operations that access the array value that is loaded by
256 /// the array load operation, `load`.
257 const llvm::SmallVector<mlir::Operation *> &
258 ArrayCopyAnalysis::arrayAccesses(ArrayLoadOp load) {
259   auto lmIter = loadMapSets.find(load);
260   if (lmIter != loadMapSets.end())
261     return lmIter->getSecond();
262 
263   llvm::SmallVector<mlir::Operation *> accesses;
264   UseSetT visited;
265   llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig]
266 
267   auto appendToQueue = [&](mlir::Value val) {
268     for (mlir::OpOperand &use : val.getUses())
269       if (!visited.count(&use)) {
270         visited.insert(&use);
271         queue.push_back(&use);
272       }
273   };
274 
275   // Build the set of uses of `original`.
276   // let USES = { uses of original fir.load }
277   appendToQueue(load);
278 
279   // Process the worklist until done.
280   while (!queue.empty()) {
281     mlir::OpOperand *operand = queue.pop_back_val();
282     mlir::Operation *owner = operand->getOwner();
283 
284     auto structuredLoop = [&](auto ro) {
285       if (auto blockArg = ro.iterArgToBlockArg(operand->get())) {
286         int64_t arg = blockArg.getArgNumber();
287         mlir::Value output = ro.getResult(ro.finalValue() ? arg : arg - 1);
288         appendToQueue(output);
289         appendToQueue(blockArg);
290       }
291     };
292     // TODO: this need to be updated to use the control-flow interface.
293     auto branchOp = [&](mlir::Block *dest, OperandRange operands) {
294       if (operands.empty())
295         return;
296 
297       // Check if this operand is within the range.
298       unsigned operandIndex = operand->getOperandNumber();
299       unsigned operandsStart = operands.getBeginOperandIndex();
300       if (operandIndex < operandsStart ||
301           operandIndex >= (operandsStart + operands.size()))
302         return;
303 
304       // Index the successor.
305       unsigned argIndex = operandIndex - operandsStart;
306       appendToQueue(dest->getArgument(argIndex));
307     };
308     // Thread uses into structured loop bodies and return value uses.
309     if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) {
310       structuredLoop(ro);
311     } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) {
312       structuredLoop(ro);
313     } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) {
314       // Thread any uses of fir.if that return the marked array value.
315       if (auto ifOp = rs->getParentOfType<fir::IfOp>())
316         appendToQueue(ifOp.getResult(operand->getOperandNumber()));
317     } else if (mlir::isa<ArrayFetchOp>(owner)) {
318       // Keep track of array value fetches.
319       LLVM_DEBUG(llvm::dbgs()
320                  << "add fetch {" << *owner << "} to array value set\n");
321       accesses.push_back(owner);
322     } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) {
323       // Keep track of array value updates and thread the return value uses.
324       LLVM_DEBUG(llvm::dbgs()
325                  << "add update {" << *owner << "} to array value set\n");
326       accesses.push_back(owner);
327       appendToQueue(update.getResult());
328     } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) {
329       // Keep track of array value modification and thread the return value
330       // uses.
331       LLVM_DEBUG(llvm::dbgs()
332                  << "add modify {" << *owner << "} to array value set\n");
333       accesses.push_back(owner);
334       appendToQueue(update.getResult(1));
335     } else if (auto br = mlir::dyn_cast<mlir::BranchOp>(owner)) {
336       branchOp(br.getDest(), br.getDestOperands());
337     } else if (auto br = mlir::dyn_cast<mlir::CondBranchOp>(owner)) {
338       branchOp(br.getTrueDest(), br.getTrueOperands());
339       branchOp(br.getFalseDest(), br.getFalseOperands());
340     } else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
341       // do nothing
342     } else {
343       llvm::report_fatal_error("array value reached unexpected op");
344     }
345   }
346   return loadMapSets.insert({load, accesses}).first->getSecond();
347 }
348 
349 /// Is there a conflict between the array value that was updated and to be
350 /// stored to `st` and the set of arrays loaded (`reach`) and used to compute
351 /// the updated value?
352 static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach,
353                            ArrayMergeStoreOp st) {
354   mlir::Value load;
355   mlir::Value addr = st.memref();
356   auto stEleTy = fir::dyn_cast_ptrOrBoxEleTy(addr.getType());
357   for (auto *op : reach) {
358     auto ld = mlir::dyn_cast<ArrayLoadOp>(op);
359     if (!ld)
360       continue;
361     mlir::Type ldTy = ld.memref().getType();
362     if (auto boxTy = ldTy.dyn_cast<fir::BoxType>())
363       ldTy = boxTy.getEleTy();
364     if (ldTy.isa<fir::PointerType>() && stEleTy == dyn_cast_ptrEleTy(ldTy))
365       return true;
366     if (ld.memref() == addr) {
367       if (ld.getResult() != st.original())
368         return true;
369       if (load)
370         return true;
371       load = ld;
372     }
373   }
374   return false;
375 }
376 
377 /// Check if there is any potential conflict in the chained update operations
378 /// (ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp) while merging back to the
379 /// array. A potential conflict is detected if two operations work on the same
380 /// indices.
381 static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> accesses) {
382   if (accesses.size() < 2)
383     return false;
384   llvm::SmallVector<mlir::Value> indices;
385   LLVM_DEBUG(llvm::dbgs() << "check merge conflict on with " << accesses.size()
386                           << " accesses on the list\n");
387   for (auto *op : accesses) {
388     assert((mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) &&
389            "unexpected operation in analysis");
390     llvm::SmallVector<mlir::Value> compareVector;
391     if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) {
392       if (indices.empty()) {
393         indices = u.indices();
394         continue;
395       }
396       compareVector = u.indices();
397     } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) {
398       if (indices.empty()) {
399         indices = f.indices();
400         continue;
401       }
402       compareVector = f.indices();
403     } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) {
404       if (indices.empty()) {
405         indices = f.indices();
406         continue;
407       }
408       compareVector = f.indices();
409     }
410     if (compareVector != indices)
411       return true;
412     LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n");
413   }
414   return false;
415 }
416 
417 // Are either of types of conflicts present?
418 inline bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
419                              llvm::ArrayRef<mlir::Operation *> accesses,
420                              ArrayMergeStoreOp st) {
421   return conflictOnLoad(reach, st) || conflictOnMerge(accesses);
422 }
423 
424 /// Constructor of the array copy analysis.
425 /// This performs the analysis and saves the intermediate results.
426 void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) {
427   topLevelOp->walk([&](Operation *op) {
428     if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) {
429       llvm::SmallVector<Operation *> values;
430       ReachCollector::reachingValues(values, st.sequence());
431       const llvm::SmallVector<Operation *> &accesses =
432           arrayAccesses(mlir::cast<ArrayLoadOp>(st.original().getDefiningOp()));
433       if (conflictDetected(values, accesses, st)) {
434         LLVM_DEBUG(llvm::dbgs()
435                    << "CONFLICT: copies required for " << st << '\n'
436                    << "   adding conflicts on: " << op << " and "
437                    << st.original() << '\n');
438         conflicts.insert(op);
439         conflicts.insert(st.original().getDefiningOp());
440       }
441       auto *ld = st.original().getDefiningOp();
442       LLVM_DEBUG(llvm::dbgs()
443                  << "map: adding {" << *ld << " -> " << st << "}\n");
444       useMap.insert({ld, op});
445     } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) {
446       const llvm::SmallVector<mlir::Operation *> &accesses =
447           arrayAccesses(load);
448       LLVM_DEBUG(llvm::dbgs() << "process load: " << load
449                               << ", accesses: " << accesses.size() << '\n');
450       for (auto *acc : accesses) {
451         LLVM_DEBUG(llvm::dbgs() << " access: " << *acc << '\n');
452         assert((mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(acc)));
453         if (!useMap.insert({acc, op}).second) {
454           mlir::emitError(
455               load.getLoc(),
456               "The parallel semantics of multiple array_merge_stores per "
457               "array_load are not supported.");
458           return;
459         }
460         LLVM_DEBUG(llvm::dbgs()
461                    << "map: adding {" << *acc << "} -> {" << load << "}\n");
462       }
463     }
464   });
465 }
466 
467 namespace {
468 class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> {
469 public:
470   using OpRewritePattern::OpRewritePattern;
471 
472   mlir::LogicalResult
473   matchAndRewrite(ArrayLoadOp load,
474                   mlir::PatternRewriter &rewriter) const override {
475     LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n");
476     rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType());
477     return mlir::success();
478   }
479 };
480 
481 class ArrayMergeStoreConversion
482     : public mlir::OpRewritePattern<ArrayMergeStoreOp> {
483 public:
484   using OpRewritePattern::OpRewritePattern;
485 
486   mlir::LogicalResult
487   matchAndRewrite(ArrayMergeStoreOp store,
488                   mlir::PatternRewriter &rewriter) const override {
489     LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n");
490     rewriter.eraseOp(store);
491     return mlir::success();
492   }
493 };
494 } // namespace
495 
496 static mlir::Type getEleTy(mlir::Type ty) {
497   if (auto t = dyn_cast_ptrEleTy(ty))
498     ty = t;
499   if (auto t = ty.dyn_cast<SequenceType>())
500     ty = t.getEleTy();
501   // FIXME: keep ptr/heap/ref information.
502   return ReferenceType::get(ty);
503 }
504 
505 // Extract extents from the ShapeOp/ShapeShiftOp into the result vector.
506 // TODO: getExtents on op should return a ValueRange instead of a vector.
507 static void getExtents(llvm::SmallVectorImpl<mlir::Value> &result,
508                        mlir::Value shape) {
509   auto *shapeOp = shape.getDefiningOp();
510   if (auto s = mlir::dyn_cast<fir::ShapeOp>(shapeOp)) {
511     auto e = s.getExtents();
512     result.insert(result.end(), e.begin(), e.end());
513     return;
514   }
515   if (auto s = mlir::dyn_cast<fir::ShapeShiftOp>(shapeOp)) {
516     auto e = s.getExtents();
517     result.insert(result.end(), e.begin(), e.end());
518     return;
519   }
520   llvm::report_fatal_error("not a fir.shape/fir.shape_shift op");
521 }
522 
523 // Place the extents of the array loaded by an ArrayLoadOp into the result
524 // vector and return a ShapeOp/ShapeShiftOp with the corresponding extents. If
525 // the ArrayLoadOp is loading a fir.box, code will be generated to read the
526 // extents from the fir.box, and a the retunred ShapeOp is built with the read
527 // extents.
528 // Otherwise, the extents will be extracted from the ShapeOp/ShapeShiftOp
529 // argument of the ArrayLoadOp that is returned.
530 static mlir::Value
531 getOrReadExtentsAndShapeOp(mlir::Location loc, mlir::PatternRewriter &rewriter,
532                            fir::ArrayLoadOp loadOp,
533                            llvm::SmallVectorImpl<mlir::Value> &result) {
534   assert(result.empty());
535   if (auto boxTy = loadOp.memref().getType().dyn_cast<fir::BoxType>()) {
536     auto rank = fir::dyn_cast_ptrOrBoxEleTy(boxTy)
537                     .cast<fir::SequenceType>()
538                     .getDimension();
539     auto idxTy = rewriter.getIndexType();
540     for (decltype(rank) dim = 0; dim < rank; ++dim) {
541       auto dimVal = rewriter.create<arith::ConstantIndexOp>(loc, dim);
542       auto dimInfo = rewriter.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy,
543                                                      loadOp.memref(), dimVal);
544       result.emplace_back(dimInfo.getResult(1));
545     }
546     auto shapeType = fir::ShapeType::get(rewriter.getContext(), rank);
547     return rewriter.create<fir::ShapeOp>(loc, shapeType, result);
548   }
549   getExtents(result, loadOp.shape());
550   return loadOp.shape();
551 }
552 
553 static mlir::Type toRefType(mlir::Type ty) {
554   if (fir::isa_ref_type(ty))
555     return ty;
556   return fir::ReferenceType::get(ty);
557 }
558 
559 static mlir::Value
560 genCoorOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type eleTy,
561           mlir::Type resTy, mlir::Value alloc, mlir::Value shape,
562           mlir::Value slice, mlir::ValueRange indices,
563           mlir::ValueRange typeparams, bool skipOrig = false) {
564   llvm::SmallVector<mlir::Value> originated;
565   if (skipOrig)
566     originated.assign(indices.begin(), indices.end());
567   else
568     originated = fir::factory::originateIndices(loc, rewriter, alloc.getType(),
569                                                 shape, indices);
570   auto seqTy = fir::dyn_cast_ptrOrBoxEleTy(alloc.getType());
571   assert(seqTy && seqTy.isa<fir::SequenceType>());
572   const auto dimension = seqTy.cast<fir::SequenceType>().getDimension();
573   mlir::Value result = rewriter.create<fir::ArrayCoorOp>(
574       loc, eleTy, alloc, shape, slice,
575       llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension),
576       typeparams);
577   if (dimension < originated.size())
578     result = rewriter.create<fir::CoordinateOp>(
579         loc, resTy, result,
580         llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension));
581   return result;
582 }
583 
584 namespace {
585 /// Conversion of fir.array_update and fir.array_modify Ops.
586 /// If there is a conflict for the update, then we need to perform a
587 /// copy-in/copy-out to preserve the original values of the array. If there is
588 /// no conflict, then it is save to eschew making any copies.
589 template <typename ArrayOp>
590 class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> {
591 public:
592   explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx,
593                                      const ArrayCopyAnalysis &a,
594                                      const OperationUseMapT &m)
595       : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {}
596 
597   void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter,
598                     mlir::Value dst, mlir::Value src, mlir::Value shapeOp,
599                     mlir::Type arrTy) const {
600     auto insPt = rewriter.saveInsertionPoint();
601     llvm::SmallVector<mlir::Value> indices;
602     llvm::SmallVector<mlir::Value> extents;
603     getExtents(extents, shapeOp);
604     // Build loop nest from column to row.
605     for (auto sh : llvm::reverse(extents)) {
606       auto idxTy = rewriter.getIndexType();
607       auto ubi = rewriter.create<fir::ConvertOp>(loc, idxTy, sh);
608       auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
609       auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
610       auto ub = rewriter.create<arith::SubIOp>(loc, idxTy, ubi, one);
611       auto loop = rewriter.create<fir::DoLoopOp>(loc, zero, ub, one);
612       rewriter.setInsertionPointToStart(loop.getBody());
613       indices.push_back(loop.getInductionVar());
614     }
615     // Reverse the indices so they are in column-major order.
616     std::reverse(indices.begin(), indices.end());
617     auto ty = getEleTy(arrTy);
618     auto fromAddr = rewriter.create<fir::ArrayCoorOp>(
619         loc, ty, src, shapeOp, mlir::Value{},
620         fir::factory::originateIndices(loc, rewriter, src.getType(), shapeOp,
621                                        indices),
622         mlir::ValueRange{});
623     auto load = rewriter.create<fir::LoadOp>(loc, fromAddr);
624     auto toAddr = rewriter.create<fir::ArrayCoorOp>(
625         loc, ty, dst, shapeOp, mlir::Value{},
626         fir::factory::originateIndices(loc, rewriter, dst.getType(), shapeOp,
627                                        indices),
628         mlir::ValueRange{});
629     rewriter.create<fir::StoreOp>(loc, load, toAddr);
630     rewriter.restoreInsertionPoint(insPt);
631   }
632 
633   /// Copy the RHS element into the LHS and insert copy-in/copy-out between a
634   /// temp and the LHS if the analysis found potential overlaps between the RHS
635   /// and LHS arrays. The element copy generator must be provided through \p
636   /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp.
637   /// Returns the address of the LHS element inside the loop and the LHS
638   /// ArrayLoad result.
639   std::pair<mlir::Value, mlir::Value>
640   materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter,
641                         ArrayOp update,
642                         llvm::function_ref<void(mlir::Value)> assignElement,
643                         mlir::Type lhsEltRefType) const {
644     auto *op = update.getOperation();
645     mlir::Operation *loadOp = useMap.lookup(op);
646     auto load = mlir::cast<ArrayLoadOp>(loadOp);
647     LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n");
648     if (analysis.hasPotentialConflict(loadOp)) {
649       // If there is a conflict between the arrays, then we copy the lhs array
650       // to a temporary, update the temporary, and copy the temporary back to
651       // the lhs array. This yields Fortran's copy-in copy-out array semantics.
652       LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n");
653       rewriter.setInsertionPoint(loadOp);
654       // Copy in.
655       llvm::SmallVector<mlir::Value> extents;
656       mlir::Value shapeOp =
657           getOrReadExtentsAndShapeOp(loc, rewriter, load, extents);
658       auto allocmem = rewriter.create<AllocMemOp>(
659           loc, dyn_cast_ptrOrBoxEleTy(load.memref().getType()),
660           load.typeparams(), extents);
661       genArrayCopy(load.getLoc(), rewriter, allocmem, load.memref(), shapeOp,
662                    load.getType());
663       rewriter.setInsertionPoint(op);
664       mlir::Value coor = genCoorOp(
665           rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem,
666           shapeOp, load.slice(), update.indices(), load.typeparams(),
667           update->hasAttr(fir::factory::attrFortranArrayOffsets()));
668       assignElement(coor);
669       mlir::Operation *storeOp = useMap.lookup(loadOp);
670       auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
671       rewriter.setInsertionPoint(storeOp);
672       // Copy out.
673       genArrayCopy(store.getLoc(), rewriter, store.memref(), allocmem, shapeOp,
674                    load.getType());
675       rewriter.create<FreeMemOp>(loc, allocmem);
676       return {coor, load.getResult()};
677     }
678     // Otherwise, when there is no conflict (a possible loop-carried
679     // dependence), the lhs array can be updated in place.
680     LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n");
681     rewriter.setInsertionPoint(op);
682     auto coorTy = getEleTy(load.getType());
683     mlir::Value coor = genCoorOp(
684         rewriter, loc, coorTy, lhsEltRefType, load.memref(), load.shape(),
685         load.slice(), update.indices(), load.typeparams(),
686         update->hasAttr(fir::factory::attrFortranArrayOffsets()));
687     assignElement(coor);
688     return {coor, load.getResult()};
689   }
690 
691 private:
692   const ArrayCopyAnalysis &analysis;
693   const OperationUseMapT &useMap;
694 };
695 
696 class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> {
697 public:
698   explicit ArrayUpdateConversion(mlir::MLIRContext *ctx,
699                                  const ArrayCopyAnalysis &a,
700                                  const OperationUseMapT &m)
701       : ArrayUpdateConversionBase{ctx, a, m} {}
702 
703   mlir::LogicalResult
704   matchAndRewrite(ArrayUpdateOp update,
705                   mlir::PatternRewriter &rewriter) const override {
706     auto loc = update.getLoc();
707     auto assignElement = [&](mlir::Value coor) {
708       rewriter.create<fir::StoreOp>(loc, update.merge(), coor);
709     };
710     auto lhsEltRefType = toRefType(update.merge().getType());
711     auto [_, lhsLoadResult] = materializeAssignment(
712         loc, rewriter, update, assignElement, lhsEltRefType);
713     update.replaceAllUsesWith(lhsLoadResult);
714     rewriter.replaceOp(update, lhsLoadResult);
715     return mlir::success();
716   }
717 };
718 
719 class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> {
720 public:
721   explicit ArrayModifyConversion(mlir::MLIRContext *ctx,
722                                  const ArrayCopyAnalysis &a,
723                                  const OperationUseMapT &m)
724       : ArrayUpdateConversionBase{ctx, a, m} {}
725 
726   mlir::LogicalResult
727   matchAndRewrite(ArrayModifyOp modify,
728                   mlir::PatternRewriter &rewriter) const override {
729     auto loc = modify.getLoc();
730     auto assignElement = [](mlir::Value) {
731       // Assignment already materialized by lowering using lhs element address.
732     };
733     auto lhsEltRefType = modify.getResult(0).getType();
734     auto [lhsEltCoor, lhsLoadResult] = materializeAssignment(
735         loc, rewriter, modify, assignElement, lhsEltRefType);
736     modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult});
737     rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult});
738     return mlir::success();
739   }
740 };
741 
742 class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> {
743 public:
744   explicit ArrayFetchConversion(mlir::MLIRContext *ctx,
745                                 const OperationUseMapT &m)
746       : OpRewritePattern{ctx}, useMap{m} {}
747 
748   mlir::LogicalResult
749   matchAndRewrite(ArrayFetchOp fetch,
750                   mlir::PatternRewriter &rewriter) const override {
751     auto *op = fetch.getOperation();
752     rewriter.setInsertionPoint(op);
753     auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
754     auto loc = fetch.getLoc();
755     mlir::Value coor =
756         genCoorOp(rewriter, loc, getEleTy(load.getType()),
757                   toRefType(fetch.getType()), load.memref(), load.shape(),
758                   load.slice(), fetch.indices(), load.typeparams(),
759                   fetch->hasAttr(fir::factory::attrFortranArrayOffsets()));
760     rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor);
761     return mlir::success();
762   }
763 
764 private:
765   const OperationUseMapT &useMap;
766 };
767 } // namespace
768 
769 namespace {
770 class ArrayValueCopyConverter
771     : public ArrayValueCopyBase<ArrayValueCopyConverter> {
772 public:
773   void runOnOperation() override {
774     auto func = getOperation();
775     LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '"
776                             << func.getName() << "'\n");
777     auto *context = &getContext();
778 
779     // Perform the conflict analysis.
780     auto &analysis = getAnalysis<ArrayCopyAnalysis>();
781     const auto &useMap = analysis.getUseMap();
782 
783     // Phase 1 is performing a rewrite on the array accesses. Once all the
784     // array accesses are rewritten we can go on phase 2.
785     // Phase 2 gets rid of the useless copy-in/copyout operations. The copy-in
786     // /copy-out refers the Fortran copy-in/copy-out semantics on statements.
787     mlir::OwningRewritePatternList patterns1(context);
788     patterns1.insert<ArrayFetchConversion>(context, useMap);
789     patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
790     patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
791     mlir::ConversionTarget target(*context);
792     target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
793                            mlir::arith::ArithmeticDialect,
794                            mlir::StandardOpsDialect>();
795     target.addIllegalOp<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>();
796     // Rewrite the array fetch and array update ops.
797     if (mlir::failed(
798             mlir::applyPartialConversion(func, target, std::move(patterns1)))) {
799       mlir::emitError(mlir::UnknownLoc::get(context),
800                       "failure in array-value-copy pass, phase 1");
801       signalPassFailure();
802     }
803 
804     mlir::OwningRewritePatternList patterns2(context);
805     patterns2.insert<ArrayLoadConversion>(context);
806     patterns2.insert<ArrayMergeStoreConversion>(context);
807     target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>();
808     if (mlir::failed(
809             mlir::applyPartialConversion(func, target, std::move(patterns2)))) {
810       mlir::emitError(mlir::UnknownLoc::get(context),
811                       "failure in array-value-copy pass, phase 2");
812       signalPassFailure();
813     }
814   }
815 };
816 } // namespace
817 
818 std::unique_ptr<mlir::Pass> fir::createArrayValueCopyPass() {
819   return std::make_unique<ArrayValueCopyConverter>();
820 }
821