xref: /llvm-project/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp (revision 149ad3d554c6e901a57c7e34e29fba334dcd283c)
1 //===-- ArrayValueCopy.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Builder/BoxValue.h"
11 #include "flang/Optimizer/Builder/FIRBuilder.h"
12 #include "flang/Optimizer/Builder/Factory.h"
13 #include "flang/Optimizer/Dialect/FIRDialect.h"
14 #include "flang/Optimizer/Support/FIRContext.h"
15 #include "flang/Optimizer/Transforms/Passes.h"
16 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "flang-array-value-copy"
22 
23 using namespace fir;
24 
25 using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>;
26 
27 namespace {
28 
29 /// Array copy analysis.
30 /// Perform an interference analysis between array values.
31 ///
32 /// Lowering will generate a sequence of the following form.
33 /// ```mlir
34 ///   %a_1 = fir.array_load %array_1(%shape) : ...
35 ///   ...
36 ///   %a_j = fir.array_load %array_j(%shape) : ...
37 ///   ...
38 ///   %a_n = fir.array_load %array_n(%shape) : ...
39 ///     ...
40 ///     %v_i = fir.array_fetch %a_i, ...
41 ///     %a_j1 = fir.array_update %a_j, ...
42 ///     ...
43 ///   fir.array_merge_store %a_j, %a_jn to %array_j : ...
44 /// ```
45 ///
46 /// The analysis is to determine if there are any conflicts. A conflict is when
47 /// one the following cases occurs.
48 ///
49 /// 1. There is an `array_update` to an array value, a_j, such that a_j was
50 /// loaded from the same array memory reference (array_j) but with a different
51 /// shape as the other array values a_i, where i != j. [Possible overlapping
52 /// arrays.]
53 ///
54 /// 2. There is either an array_fetch or array_update of a_j with a different
55 /// set of index values. [Possible loop-carried dependence.]
56 ///
57 /// If none of the array values overlap in storage and the accesses are not
58 /// loop-carried, then the arrays are conflict-free and no copies are required.
59 class ArrayCopyAnalysis {
60 public:
61   using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>;
62   using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>;
63   using LoadMapSetsT =
64       llvm::DenseMap<mlir::Operation *, SmallVector<Operation *>>;
65 
66   ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); }
67 
68   mlir::Operation *getOperation() const { return operation; }
69 
70   /// Return true iff the `array_merge_store` has potential conflicts.
71   bool hasPotentialConflict(mlir::Operation *op) const {
72     LLVM_DEBUG(llvm::dbgs()
73                << "looking for a conflict on " << *op
74                << " and the set has a total of " << conflicts.size() << '\n');
75     return conflicts.contains(op);
76   }
77 
78   /// Return the use map. The use map maps array fetch and update operations
79   /// back to the array load that is the original source of the array value.
80   const OperationUseMapT &getUseMap() const { return useMap; }
81 
82   /// Find all the array operations that access the array value that is loaded
83   /// by the array load operation, `load`.
84   const llvm::SmallVector<mlir::Operation *> &arrayAccesses(ArrayLoadOp load);
85 
86 private:
87   void construct(mlir::Operation *topLevelOp);
88 
89   mlir::Operation *operation; // operation that analysis ran upon
90   ConflictSetT conflicts;     // set of conflicts (loads and merge stores)
91   OperationUseMapT useMap;
92   LoadMapSetsT loadMapSets;
93 };
94 } // namespace
95 
96 namespace {
97 /// Helper class to collect all array operations that produced an array value.
98 class ReachCollector {
99 private:
100   // If provided, the `loopRegion` is the body of a loop that produces the array
101   // of interest.
102   ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach,
103                  mlir::Region *loopRegion)
104       : reach{reach}, loopRegion{loopRegion} {}
105 
106   void collectArrayAccessFrom(mlir::Operation *op, mlir::ValueRange range) {
107     llvm::errs() << "COLLECT " << *op << "\n";
108     if (range.empty()) {
109       collectArrayAccessFrom(op, mlir::Value{});
110       return;
111     }
112     for (mlir::Value v : range)
113       collectArrayAccessFrom(v);
114   }
115 
116   // TODO: Replace recursive algorithm on def-use chain with an iterative one
117   // with an explicit stack.
118   void collectArrayAccessFrom(mlir::Operation *op, mlir::Value val) {
119     // `val` is defined by an Op, process the defining Op.
120     // If `val` is defined by a region containing Op, we want to drill down
121     // and through that Op's region(s).
122     llvm::errs() << "COLLECT " << *op << "\n";
123     LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n');
124     auto popFn = [&](auto rop) {
125       assert(val && "op must have a result value");
126       auto resNum = val.cast<mlir::OpResult>().getResultNumber();
127       llvm::SmallVector<mlir::Value> results;
128       rop.resultToSourceOps(results, resNum);
129       for (auto u : results)
130         collectArrayAccessFrom(u);
131     };
132     if (auto rop = mlir::dyn_cast<fir::DoLoopOp>(op)) {
133       popFn(rop);
134       return;
135     }
136     if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) {
137       popFn(rop);
138       return;
139     }
140     if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
141       if (opIsInsideLoops(mergeStore))
142         collectArrayAccessFrom(mergeStore.getSequence());
143       return;
144     }
145 
146     if (mlir::isa<AllocaOp, AllocMemOp>(op)) {
147       // Look for any stores inside the loops, and collect an array operation
148       // that produced the value being stored to it.
149       for (mlir::Operation *user : op->getUsers())
150         if (auto store = mlir::dyn_cast<fir::StoreOp>(user))
151           if (opIsInsideLoops(store))
152             collectArrayAccessFrom(store.getValue());
153       return;
154     }
155 
156     // Otherwise, Op does not contain a region so just chase its operands.
157     if (mlir::isa<ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp, ArrayFetchOp>(
158             op)) {
159       LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
160       reach.emplace_back(op);
161     }
162     // Array modify assignment is performed on the result. So the analysis
163     // must look at the what is done with the result.
164     if (mlir::isa<ArrayModifyOp>(op))
165       for (mlir::Operation *user : op->getResult(0).getUsers())
166         followUsers(user);
167 
168     for (auto u : op->getOperands())
169       collectArrayAccessFrom(u);
170   }
171 
172   void collectArrayAccessFrom(mlir::BlockArgument ba) {
173     auto *parent = ba.getOwner()->getParentOp();
174     // If inside an Op holding a region, the block argument corresponds to an
175     // argument passed to the containing Op.
176     auto popFn = [&](auto rop) {
177       collectArrayAccessFrom(rop.blockArgToSourceOp(ba.getArgNumber()));
178     };
179     if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) {
180       popFn(rop);
181       return;
182     }
183     if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) {
184       popFn(rop);
185       return;
186     }
187     // Otherwise, a block argument is provided via the pred blocks.
188     for (auto *pred : ba.getOwner()->getPredecessors()) {
189       auto u = pred->getTerminator()->getOperand(ba.getArgNumber());
190       collectArrayAccessFrom(u);
191     }
192   }
193 
194   // Recursively trace operands to find all array operations relating to the
195   // values merged.
196   void collectArrayAccessFrom(mlir::Value val) {
197     if (!val || visited.contains(val))
198       return;
199     visited.insert(val);
200 
201     // Process a block argument.
202     if (auto ba = val.dyn_cast<mlir::BlockArgument>()) {
203       collectArrayAccessFrom(ba);
204       return;
205     }
206 
207     // Process an Op.
208     if (auto *op = val.getDefiningOp()) {
209       collectArrayAccessFrom(op, val);
210       return;
211     }
212 
213     fir::emitFatalError(val.getLoc(), "unhandled value");
214   }
215 
216   /// Is \op inside the loop nest region ?
217   bool opIsInsideLoops(mlir::Operation *op) const {
218     return loopRegion && loopRegion->isAncestor(op->getParentRegion());
219   }
220 
221   /// Recursively trace the use of an operation results, calling
222   /// collectArrayAccessFrom on the direct and indirect user operands.
223   /// TODO: Replace recursive algorithm on def-use chain with an iterative one
224   /// with an explicit stack.
225   void followUsers(mlir::Operation *op) {
226     for (auto userOperand : op->getOperands())
227       collectArrayAccessFrom(userOperand);
228     // Go through potential converts/coordinate_op.
229     for (mlir::Operation *indirectUser : op->getUsers())
230       followUsers(indirectUser);
231   }
232 
233   llvm::SmallVectorImpl<mlir::Operation *> &reach;
234   llvm::SmallPtrSet<mlir::Value, 16> visited;
235   /// Region of the loops nest that produced the array value.
236   mlir::Region *loopRegion;
237 
238 public:
239   /// Return all ops that produce the array value that is stored into the
240   /// `array_merge_store`.
241   static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach,
242                              mlir::Value seq) {
243     reach.clear();
244     mlir::Region *loopRegion = nullptr;
245     // Only `DoLoopOp` is tested here since array operations are currently only
246     // associated with this kind of loop.
247     if (auto doLoop =
248             mlir::dyn_cast_or_null<fir::DoLoopOp>(seq.getDefiningOp()))
249       loopRegion = &doLoop->getRegion(0);
250     ReachCollector collector(reach, loopRegion);
251     collector.collectArrayAccessFrom(seq);
252   }
253 };
254 } // namespace
255 
256 /// Find all the array operations that access the array value that is loaded by
257 /// the array load operation, `load`.
258 const llvm::SmallVector<mlir::Operation *> &
259 ArrayCopyAnalysis::arrayAccesses(ArrayLoadOp load) {
260   auto lmIter = loadMapSets.find(load);
261   if (lmIter != loadMapSets.end())
262     return lmIter->getSecond();
263 
264   llvm::SmallVector<mlir::Operation *> accesses;
265   UseSetT visited;
266   llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig]
267 
268   auto appendToQueue = [&](mlir::Value val) {
269     for (mlir::OpOperand &use : val.getUses())
270       if (!visited.count(&use)) {
271         visited.insert(&use);
272         queue.push_back(&use);
273       }
274   };
275 
276   // Build the set of uses of `original`.
277   // let USES = { uses of original fir.load }
278   appendToQueue(load);
279 
280   // Process the worklist until done.
281   while (!queue.empty()) {
282     mlir::OpOperand *operand = queue.pop_back_val();
283     mlir::Operation *owner = operand->getOwner();
284 
285     auto structuredLoop = [&](auto ro) {
286       if (auto blockArg = ro.iterArgToBlockArg(operand->get())) {
287         int64_t arg = blockArg.getArgNumber();
288         mlir::Value output = ro.getResult(ro.getFinalValue() ? arg : arg - 1);
289         appendToQueue(output);
290         appendToQueue(blockArg);
291       }
292     };
293     // TODO: this need to be updated to use the control-flow interface.
294     auto branchOp = [&](mlir::Block *dest, OperandRange operands) {
295       if (operands.empty())
296         return;
297 
298       // Check if this operand is within the range.
299       unsigned operandIndex = operand->getOperandNumber();
300       unsigned operandsStart = operands.getBeginOperandIndex();
301       if (operandIndex < operandsStart ||
302           operandIndex >= (operandsStart + operands.size()))
303         return;
304 
305       // Index the successor.
306       unsigned argIndex = operandIndex - operandsStart;
307       appendToQueue(dest->getArgument(argIndex));
308     };
309     // Thread uses into structured loop bodies and return value uses.
310     if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) {
311       structuredLoop(ro);
312     } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) {
313       structuredLoop(ro);
314     } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) {
315       // Thread any uses of fir.if that return the marked array value.
316       if (auto ifOp = rs->getParentOfType<fir::IfOp>())
317         appendToQueue(ifOp.getResult(operand->getOperandNumber()));
318     } else if (mlir::isa<ArrayFetchOp>(owner)) {
319       // Keep track of array value fetches.
320       LLVM_DEBUG(llvm::dbgs()
321                  << "add fetch {" << *owner << "} to array value set\n");
322       accesses.push_back(owner);
323     } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) {
324       // Keep track of array value updates and thread the return value uses.
325       LLVM_DEBUG(llvm::dbgs()
326                  << "add update {" << *owner << "} to array value set\n");
327       accesses.push_back(owner);
328       appendToQueue(update.getResult());
329     } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) {
330       // Keep track of array value modification and thread the return value
331       // uses.
332       LLVM_DEBUG(llvm::dbgs()
333                  << "add modify {" << *owner << "} to array value set\n");
334       accesses.push_back(owner);
335       appendToQueue(update.getResult(1));
336     } else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) {
337       branchOp(br.getDest(), br.getDestOperands());
338     } else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) {
339       branchOp(br.getTrueDest(), br.getTrueOperands());
340       branchOp(br.getFalseDest(), br.getFalseOperands());
341     } else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
342       // do nothing
343     } else {
344       llvm::report_fatal_error("array value reached unexpected op");
345     }
346   }
347   return loadMapSets.insert({load, accesses}).first->getSecond();
348 }
349 
350 /// Is there a conflict between the array value that was updated and to be
351 /// stored to `st` and the set of arrays loaded (`reach`) and used to compute
352 /// the updated value?
353 static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach,
354                            ArrayMergeStoreOp st) {
355   mlir::Value load;
356   mlir::Value addr = st.getMemref();
357   auto stEleTy = fir::dyn_cast_ptrOrBoxEleTy(addr.getType());
358   for (auto *op : reach) {
359     auto ld = mlir::dyn_cast<ArrayLoadOp>(op);
360     if (!ld)
361       continue;
362     mlir::Type ldTy = ld.getMemref().getType();
363     if (auto boxTy = ldTy.dyn_cast<fir::BoxType>())
364       ldTy = boxTy.getEleTy();
365     if (ldTy.isa<fir::PointerType>() && stEleTy == dyn_cast_ptrEleTy(ldTy))
366       return true;
367     if (ld.getMemref() == addr) {
368       if (ld.getResult() != st.getOriginal())
369         return true;
370       if (load)
371         return true;
372       load = ld;
373     }
374   }
375   return false;
376 }
377 
378 /// Check if there is any potential conflict in the chained update operations
379 /// (ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp) while merging back to the
380 /// array. A potential conflict is detected if two operations work on the same
381 /// indices.
382 static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> accesses) {
383   if (accesses.size() < 2)
384     return false;
385   llvm::SmallVector<mlir::Value> indices;
386   LLVM_DEBUG(llvm::dbgs() << "check merge conflict on with " << accesses.size()
387                           << " accesses on the list\n");
388   for (auto *op : accesses) {
389     assert((mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) &&
390            "unexpected operation in analysis");
391     llvm::SmallVector<mlir::Value> compareVector;
392     if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) {
393       if (indices.empty()) {
394         indices = u.getIndices();
395         continue;
396       }
397       compareVector = u.getIndices();
398     } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) {
399       if (indices.empty()) {
400         indices = f.getIndices();
401         continue;
402       }
403       compareVector = f.getIndices();
404     } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) {
405       if (indices.empty()) {
406         indices = f.getIndices();
407         continue;
408       }
409       compareVector = f.getIndices();
410     }
411     if (compareVector != indices)
412       return true;
413     LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n");
414   }
415   return false;
416 }
417 
418 // Are either of types of conflicts present?
419 inline bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
420                              llvm::ArrayRef<mlir::Operation *> accesses,
421                              ArrayMergeStoreOp st) {
422   return conflictOnLoad(reach, st) || conflictOnMerge(accesses);
423 }
424 
425 /// Constructor of the array copy analysis.
426 /// This performs the analysis and saves the intermediate results.
427 void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) {
428   topLevelOp->walk([&](Operation *op) {
429     if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) {
430       llvm::SmallVector<Operation *> values;
431       ReachCollector::reachingValues(values, st.getSequence());
432       const llvm::SmallVector<Operation *> &accesses = arrayAccesses(
433           mlir::cast<ArrayLoadOp>(st.getOriginal().getDefiningOp()));
434       if (conflictDetected(values, accesses, st)) {
435         LLVM_DEBUG(llvm::dbgs()
436                    << "CONFLICT: copies required for " << st << '\n'
437                    << "   adding conflicts on: " << op << " and "
438                    << st.getOriginal() << '\n');
439         conflicts.insert(op);
440         conflicts.insert(st.getOriginal().getDefiningOp());
441       }
442       auto *ld = st.getOriginal().getDefiningOp();
443       LLVM_DEBUG(llvm::dbgs()
444                  << "map: adding {" << *ld << " -> " << st << "}\n");
445       useMap.insert({ld, op});
446     } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) {
447       const llvm::SmallVector<mlir::Operation *> &accesses =
448           arrayAccesses(load);
449       LLVM_DEBUG(llvm::dbgs() << "process load: " << load
450                               << ", accesses: " << accesses.size() << '\n');
451       for (auto *acc : accesses) {
452         LLVM_DEBUG(llvm::dbgs() << " access: " << *acc << '\n');
453         assert((mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(acc)));
454         if (!useMap.insert({acc, op}).second) {
455           mlir::emitError(
456               load.getLoc(),
457               "The parallel semantics of multiple array_merge_stores per "
458               "array_load are not supported.");
459           return;
460         }
461         LLVM_DEBUG(llvm::dbgs()
462                    << "map: adding {" << *acc << "} -> {" << load << "}\n");
463       }
464     }
465   });
466 }
467 
468 namespace {
469 class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> {
470 public:
471   using OpRewritePattern::OpRewritePattern;
472 
473   mlir::LogicalResult
474   matchAndRewrite(ArrayLoadOp load,
475                   mlir::PatternRewriter &rewriter) const override {
476     LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n");
477     rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType());
478     return mlir::success();
479   }
480 };
481 
482 class ArrayMergeStoreConversion
483     : public mlir::OpRewritePattern<ArrayMergeStoreOp> {
484 public:
485   using OpRewritePattern::OpRewritePattern;
486 
487   mlir::LogicalResult
488   matchAndRewrite(ArrayMergeStoreOp store,
489                   mlir::PatternRewriter &rewriter) const override {
490     LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n");
491     rewriter.eraseOp(store);
492     return mlir::success();
493   }
494 };
495 } // namespace
496 
497 static mlir::Type getEleTy(mlir::Type ty) {
498   if (auto t = dyn_cast_ptrEleTy(ty))
499     ty = t;
500   if (auto t = ty.dyn_cast<SequenceType>())
501     ty = t.getEleTy();
502   // FIXME: keep ptr/heap/ref information.
503   return ReferenceType::get(ty);
504 }
505 
506 // Extract extents from the ShapeOp/ShapeShiftOp into the result vector.
507 // TODO: getExtents on op should return a ValueRange instead of a vector.
508 static void getExtents(llvm::SmallVectorImpl<mlir::Value> &result,
509                        mlir::Value shape) {
510   auto *shapeOp = shape.getDefiningOp();
511   if (auto s = mlir::dyn_cast<fir::ShapeOp>(shapeOp)) {
512     auto e = s.getExtents();
513     result.insert(result.end(), e.begin(), e.end());
514     return;
515   }
516   if (auto s = mlir::dyn_cast<fir::ShapeShiftOp>(shapeOp)) {
517     auto e = s.getExtents();
518     result.insert(result.end(), e.begin(), e.end());
519     return;
520   }
521   llvm::report_fatal_error("not a fir.shape/fir.shape_shift op");
522 }
523 
524 // Place the extents of the array loaded by an ArrayLoadOp into the result
525 // vector and return a ShapeOp/ShapeShiftOp with the corresponding extents. If
526 // the ArrayLoadOp is loading a fir.box, code will be generated to read the
527 // extents from the fir.box, and a the retunred ShapeOp is built with the read
528 // extents.
529 // Otherwise, the extents will be extracted from the ShapeOp/ShapeShiftOp
530 // argument of the ArrayLoadOp that is returned.
531 static mlir::Value
532 getOrReadExtentsAndShapeOp(mlir::Location loc, mlir::PatternRewriter &rewriter,
533                            fir::ArrayLoadOp loadOp,
534                            llvm::SmallVectorImpl<mlir::Value> &result) {
535   assert(result.empty());
536   if (auto boxTy = loadOp.getMemref().getType().dyn_cast<fir::BoxType>()) {
537     auto rank = fir::dyn_cast_ptrOrBoxEleTy(boxTy)
538                     .cast<fir::SequenceType>()
539                     .getDimension();
540     auto idxTy = rewriter.getIndexType();
541     for (decltype(rank) dim = 0; dim < rank; ++dim) {
542       auto dimVal = rewriter.create<arith::ConstantIndexOp>(loc, dim);
543       auto dimInfo = rewriter.create<fir::BoxDimsOp>(
544           loc, idxTy, idxTy, idxTy, loadOp.getMemref(), dimVal);
545       result.emplace_back(dimInfo.getResult(1));
546     }
547     auto shapeType = fir::ShapeType::get(rewriter.getContext(), rank);
548     return rewriter.create<fir::ShapeOp>(loc, shapeType, result);
549   }
550   getExtents(result, loadOp.getShape());
551   return loadOp.getShape();
552 }
553 
554 static mlir::Type toRefType(mlir::Type ty) {
555   if (fir::isa_ref_type(ty))
556     return ty;
557   return fir::ReferenceType::get(ty);
558 }
559 
560 static mlir::Value
561 genCoorOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type eleTy,
562           mlir::Type resTy, mlir::Value alloc, mlir::Value shape,
563           mlir::Value slice, mlir::ValueRange indices,
564           mlir::ValueRange typeparams, bool skipOrig = false) {
565   llvm::SmallVector<mlir::Value> originated;
566   if (skipOrig)
567     originated.assign(indices.begin(), indices.end());
568   else
569     originated = fir::factory::originateIndices(loc, rewriter, alloc.getType(),
570                                                 shape, indices);
571   auto seqTy = fir::dyn_cast_ptrOrBoxEleTy(alloc.getType());
572   assert(seqTy && seqTy.isa<fir::SequenceType>());
573   const auto dimension = seqTy.cast<fir::SequenceType>().getDimension();
574   mlir::Value result = rewriter.create<fir::ArrayCoorOp>(
575       loc, eleTy, alloc, shape, slice,
576       llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension),
577       typeparams);
578   if (dimension < originated.size())
579     result = rewriter.create<fir::CoordinateOp>(
580         loc, resTy, result,
581         llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension));
582   return result;
583 }
584 
585 namespace {
586 /// Conversion of fir.array_update and fir.array_modify Ops.
587 /// If there is a conflict for the update, then we need to perform a
588 /// copy-in/copy-out to preserve the original values of the array. If there is
589 /// no conflict, then it is save to eschew making any copies.
590 template <typename ArrayOp>
591 class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> {
592 public:
593   explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx,
594                                      const ArrayCopyAnalysis &a,
595                                      const OperationUseMapT &m)
596       : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {}
597 
598   void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter,
599                     mlir::Value dst, mlir::Value src, mlir::Value shapeOp,
600                     mlir::Type arrTy) const {
601     auto insPt = rewriter.saveInsertionPoint();
602     llvm::SmallVector<mlir::Value> indices;
603     llvm::SmallVector<mlir::Value> extents;
604     getExtents(extents, shapeOp);
605     // Build loop nest from column to row.
606     for (auto sh : llvm::reverse(extents)) {
607       auto idxTy = rewriter.getIndexType();
608       auto ubi = rewriter.create<fir::ConvertOp>(loc, idxTy, sh);
609       auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
610       auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
611       auto ub = rewriter.create<arith::SubIOp>(loc, idxTy, ubi, one);
612       auto loop = rewriter.create<fir::DoLoopOp>(loc, zero, ub, one);
613       rewriter.setInsertionPointToStart(loop.getBody());
614       indices.push_back(loop.getInductionVar());
615     }
616     // Reverse the indices so they are in column-major order.
617     std::reverse(indices.begin(), indices.end());
618     auto ty = getEleTy(arrTy);
619     auto fromAddr = rewriter.create<fir::ArrayCoorOp>(
620         loc, ty, src, shapeOp, mlir::Value{},
621         fir::factory::originateIndices(loc, rewriter, src.getType(), shapeOp,
622                                        indices),
623         mlir::ValueRange{});
624     auto load = rewriter.create<fir::LoadOp>(loc, fromAddr);
625     auto toAddr = rewriter.create<fir::ArrayCoorOp>(
626         loc, ty, dst, shapeOp, mlir::Value{},
627         fir::factory::originateIndices(loc, rewriter, dst.getType(), shapeOp,
628                                        indices),
629         mlir::ValueRange{});
630     rewriter.create<fir::StoreOp>(loc, load, toAddr);
631     rewriter.restoreInsertionPoint(insPt);
632   }
633 
634   /// Copy the RHS element into the LHS and insert copy-in/copy-out between a
635   /// temp and the LHS if the analysis found potential overlaps between the RHS
636   /// and LHS arrays. The element copy generator must be provided through \p
637   /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp.
638   /// Returns the address of the LHS element inside the loop and the LHS
639   /// ArrayLoad result.
640   std::pair<mlir::Value, mlir::Value>
641   materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter,
642                         ArrayOp update,
643                         llvm::function_ref<void(mlir::Value)> assignElement,
644                         mlir::Type lhsEltRefType) const {
645     auto *op = update.getOperation();
646     mlir::Operation *loadOp = useMap.lookup(op);
647     auto load = mlir::cast<ArrayLoadOp>(loadOp);
648     LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n");
649     if (analysis.hasPotentialConflict(loadOp)) {
650       // If there is a conflict between the arrays, then we copy the lhs array
651       // to a temporary, update the temporary, and copy the temporary back to
652       // the lhs array. This yields Fortran's copy-in copy-out array semantics.
653       LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n");
654       rewriter.setInsertionPoint(loadOp);
655       // Copy in.
656       llvm::SmallVector<mlir::Value> extents;
657       mlir::Value shapeOp =
658           getOrReadExtentsAndShapeOp(loc, rewriter, load, extents);
659       auto allocmem = rewriter.create<AllocMemOp>(
660           loc, dyn_cast_ptrOrBoxEleTy(load.getMemref().getType()),
661           load.getTypeparams(), extents);
662       genArrayCopy(load.getLoc(), rewriter, allocmem, load.getMemref(), shapeOp,
663                    load.getType());
664       rewriter.setInsertionPoint(op);
665       mlir::Value coor = genCoorOp(
666           rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem,
667           shapeOp, load.getSlice(), update.getIndices(), load.getTypeparams(),
668           update->hasAttr(fir::factory::attrFortranArrayOffsets()));
669       assignElement(coor);
670       mlir::Operation *storeOp = useMap.lookup(loadOp);
671       auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
672       rewriter.setInsertionPoint(storeOp);
673       // Copy out.
674       genArrayCopy(store.getLoc(), rewriter, store.getMemref(), allocmem,
675                    shapeOp, load.getType());
676       rewriter.create<FreeMemOp>(loc, allocmem);
677       return {coor, load.getResult()};
678     }
679     // Otherwise, when there is no conflict (a possible loop-carried
680     // dependence), the lhs array can be updated in place.
681     LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n");
682     rewriter.setInsertionPoint(op);
683     auto coorTy = getEleTy(load.getType());
684     mlir::Value coor = genCoorOp(
685         rewriter, loc, coorTy, lhsEltRefType, load.getMemref(), load.getShape(),
686         load.getSlice(), update.getIndices(), load.getTypeparams(),
687         update->hasAttr(fir::factory::attrFortranArrayOffsets()));
688     assignElement(coor);
689     return {coor, load.getResult()};
690   }
691 
692 private:
693   const ArrayCopyAnalysis &analysis;
694   const OperationUseMapT &useMap;
695 };
696 
697 class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> {
698 public:
699   explicit ArrayUpdateConversion(mlir::MLIRContext *ctx,
700                                  const ArrayCopyAnalysis &a,
701                                  const OperationUseMapT &m)
702       : ArrayUpdateConversionBase{ctx, a, m} {}
703 
704   mlir::LogicalResult
705   matchAndRewrite(ArrayUpdateOp update,
706                   mlir::PatternRewriter &rewriter) const override {
707     auto loc = update.getLoc();
708     auto assignElement = [&](mlir::Value coor) {
709       rewriter.create<fir::StoreOp>(loc, update.getMerge(), coor);
710     };
711     auto lhsEltRefType = toRefType(update.getMerge().getType());
712     auto [_, lhsLoadResult] = materializeAssignment(
713         loc, rewriter, update, assignElement, lhsEltRefType);
714     update.replaceAllUsesWith(lhsLoadResult);
715     rewriter.replaceOp(update, lhsLoadResult);
716     return mlir::success();
717   }
718 };
719 
720 class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> {
721 public:
722   explicit ArrayModifyConversion(mlir::MLIRContext *ctx,
723                                  const ArrayCopyAnalysis &a,
724                                  const OperationUseMapT &m)
725       : ArrayUpdateConversionBase{ctx, a, m} {}
726 
727   mlir::LogicalResult
728   matchAndRewrite(ArrayModifyOp modify,
729                   mlir::PatternRewriter &rewriter) const override {
730     auto loc = modify.getLoc();
731     auto assignElement = [](mlir::Value) {
732       // Assignment already materialized by lowering using lhs element address.
733     };
734     auto lhsEltRefType = modify.getResult(0).getType();
735     auto [lhsEltCoor, lhsLoadResult] = materializeAssignment(
736         loc, rewriter, modify, assignElement, lhsEltRefType);
737     modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult});
738     rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult});
739     return mlir::success();
740   }
741 };
742 
743 class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> {
744 public:
745   explicit ArrayFetchConversion(mlir::MLIRContext *ctx,
746                                 const OperationUseMapT &m)
747       : OpRewritePattern{ctx}, useMap{m} {}
748 
749   mlir::LogicalResult
750   matchAndRewrite(ArrayFetchOp fetch,
751                   mlir::PatternRewriter &rewriter) const override {
752     auto *op = fetch.getOperation();
753     rewriter.setInsertionPoint(op);
754     auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
755     auto loc = fetch.getLoc();
756     mlir::Value coor =
757         genCoorOp(rewriter, loc, getEleTy(load.getType()),
758                   toRefType(fetch.getType()), load.getMemref(), load.getShape(),
759                   load.getSlice(), fetch.getIndices(), load.getTypeparams(),
760                   fetch->hasAttr(fir::factory::attrFortranArrayOffsets()));
761     rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor);
762     return mlir::success();
763   }
764 
765 private:
766   const OperationUseMapT &useMap;
767 };
768 } // namespace
769 
770 namespace {
771 class ArrayValueCopyConverter
772     : public ArrayValueCopyBase<ArrayValueCopyConverter> {
773 public:
774   void runOnOperation() override {
775     auto func = getOperation();
776     LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '"
777                             << func.getName() << "'\n");
778     auto *context = &getContext();
779 
780     // Perform the conflict analysis.
781     auto &analysis = getAnalysis<ArrayCopyAnalysis>();
782     const auto &useMap = analysis.getUseMap();
783 
784     // Phase 1 is performing a rewrite on the array accesses. Once all the
785     // array accesses are rewritten we can go on phase 2.
786     // Phase 2 gets rid of the useless copy-in/copyout operations. The copy-in
787     // /copy-out refers the Fortran copy-in/copy-out semantics on statements.
788     mlir::RewritePatternSet patterns1(context);
789     patterns1.insert<ArrayFetchConversion>(context, useMap);
790     patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
791     patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
792     mlir::ConversionTarget target(*context);
793     target.addLegalDialect<
794         FIROpsDialect, mlir::scf::SCFDialect, mlir::arith::ArithmeticDialect,
795         mlir::cf::ControlFlowDialect, mlir::StandardOpsDialect>();
796     target.addIllegalOp<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>();
797     // Rewrite the array fetch and array update ops.
798     if (mlir::failed(
799             mlir::applyPartialConversion(func, target, std::move(patterns1)))) {
800       mlir::emitError(mlir::UnknownLoc::get(context),
801                       "failure in array-value-copy pass, phase 1");
802       signalPassFailure();
803     }
804 
805     mlir::RewritePatternSet patterns2(context);
806     patterns2.insert<ArrayLoadConversion>(context);
807     patterns2.insert<ArrayMergeStoreConversion>(context);
808     target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>();
809     if (mlir::failed(
810             mlir::applyPartialConversion(func, target, std::move(patterns2)))) {
811       mlir::emitError(mlir::UnknownLoc::get(context),
812                       "failure in array-value-copy pass, phase 2");
813       signalPassFailure();
814     }
815   }
816 };
817 } // namespace
818 
819 std::unique_ptr<mlir::Pass> fir::createArrayValueCopyPass() {
820   return std::make_unique<ArrayValueCopyConverter>();
821 }
822