xref: /llvm-project/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp (revision 3c700d131a35ce4b0063a4688dce4a0cb739ca83)
1 //===- OptimizedBufferization.cpp - special cases for bufferization -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // In some special cases we can bufferize hlfir expressions in a more optimal
9 // way so as to avoid creating temporaries. This pass handles these. It should
10 // be run before the catch-all bufferization pass.
11 //
12 // This requires constant subexpression elimination to have already been run.
13 //===----------------------------------------------------------------------===//
14 
15 #include "flang/Optimizer/Analysis/AliasAnalysis.h"
16 #include "flang/Optimizer/Builder/FIRBuilder.h"
17 #include "flang/Optimizer/Builder/HLFIRTools.h"
18 #include "flang/Optimizer/Dialect/FIROps.h"
19 #include "flang/Optimizer/Dialect/FIRType.h"
20 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
21 #include "flang/Optimizer/HLFIR/HLFIROps.h"
22 #include "flang/Optimizer/HLFIR/Passes.h"
23 #include "flang/Optimizer/OpenMP/Passes.h"
24 #include "flang/Optimizer/Transforms/Utils.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h"
26 #include "mlir/IR/Dominance.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/Interfaces/SideEffectInterfaces.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Support/LLVM.h"
31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include <iterator>
34 #include <memory>
35 #include <mlir/Analysis/AliasAnalysis.h>
36 #include <optional>
37 
38 namespace hlfir {
39 #define GEN_PASS_DEF_OPTIMIZEDBUFFERIZATION
40 #include "flang/Optimizer/HLFIR/Passes.h.inc"
41 } // namespace hlfir
42 
43 #define DEBUG_TYPE "opt-bufferization"
44 
45 namespace {
46 
47 /// This transformation should match in place modification of arrays.
48 /// It should match code of the form
49 /// %array = some.operation // array has shape %shape
50 /// %expr = hlfir.elemental %shape : [...] {
51 /// bb0(%arg0: index)
52 ///   %0 = hlfir.designate %array(%arg0)
53 ///   [...] // no other reads or writes to %array
54 ///   hlfir.yield_element %element
55 /// }
56 /// hlfir.assign %expr to %array
57 /// hlfir.destroy %expr
58 ///
59 /// Or
60 ///
61 /// %read_array = some.operation // shape %shape
62 /// %expr = hlfir.elemental %shape : [...] {
63 /// bb0(%arg0: index)
64 ///   %0 = hlfir.designate %read_array(%arg0)
65 ///   [...]
66 ///   hlfir.yield_element %element
67 /// }
68 /// %write_array = some.operation // with shape %shape
69 /// [...] // operations which don't effect write_array
70 /// hlfir.assign %expr to %write_array
71 /// hlfir.destroy %expr
72 ///
73 /// In these cases, it is safe to turn the elemental into a do loop and modify
74 /// elements of %array in place without creating an extra temporary for the
75 /// elemental. We must check that there are no reads from the array at indexes
76 /// which might conflict with the assignment or any writes. For now we will keep
77 /// that strict and say that all reads must be at the elemental index (it is
78 /// probably safe to read from higher indices if lowering to an ordered loop).
class ElementalAssignBufferization
    : public mlir::OpRewritePattern<hlfir::ElementalOp> {
private:
  /// Everything needed to perform the rewrite once a safe match is found.
  struct MatchInfo {
    // the array written to by the assignment (LHS of `assign`)
    mlir::Value array;
    // the single hlfir.assign user of the elemental
    hlfir::AssignOp assign;
    // the single hlfir.destroy user of the elemental
    hlfir::DestroyOp destroy;
  };
  /// determines if the transformation can be applied to this elemental;
  /// returns std::nullopt when it cannot be proven safe
  static std::optional<MatchInfo> findMatch(hlfir::ElementalOp elemental);

  /// Returns the array indices for the given hlfir.designate.
  /// It recognizes the computations used to transform the one-based indices
  /// into the array's lb-based indices, and returns the one-based indices
  /// in these cases.
  static llvm::SmallVector<mlir::Value>
  getDesignatorIndices(hlfir::DesignateOp designate);

public:
  using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern;

  /// If findMatch succeeds, inline the elemental as a loop nest assigning
  /// directly into the destination array, removing the temporary.
  llvm::LogicalResult
  matchAndRewrite(hlfir::ElementalOp elemental,
                  mlir::PatternRewriter &rewriter) const override;
};
104 
105 /// recursively collect all effects between start and end (including start, not
106 /// including end) start must properly dominate end, start and end must be in
107 /// the same block. If any operations with unknown effects are found,
108 /// std::nullopt is returned
109 static std::optional<mlir::SmallVector<mlir::MemoryEffects::EffectInstance>>
110 getEffectsBetween(mlir::Operation *start, mlir::Operation *end) {
111   mlir::SmallVector<mlir::MemoryEffects::EffectInstance> ret;
112   if (start == end)
113     return ret;
114   assert(start->getBlock() && end->getBlock() && "TODO: block arguments");
115   assert(start->getBlock() == end->getBlock());
116   assert(mlir::DominanceInfo{}.properlyDominates(start, end));
117 
118   mlir::Operation *nextOp = start;
119   while (nextOp && nextOp != end) {
120     std::optional<mlir::SmallVector<mlir::MemoryEffects::EffectInstance>>
121         effects = mlir::getEffectsRecursively(nextOp);
122     if (!effects)
123       return std::nullopt;
124     ret.append(*effects);
125     nextOp = nextOp->getNextNode();
126   }
127   return ret;
128 }
129 
130 /// If effect is a read or write on val, return whether it aliases.
131 /// Otherwise return mlir::AliasResult::NoAlias
132 static mlir::AliasResult
133 containsReadOrWriteEffectOn(const mlir::MemoryEffects::EffectInstance &effect,
134                             mlir::Value val) {
135   fir::AliasAnalysis aliasAnalysis;
136 
137   if (mlir::isa<mlir::MemoryEffects::Read, mlir::MemoryEffects::Write>(
138           effect.getEffect())) {
139     mlir::Value accessedVal = effect.getValue();
140     if (mlir::isa<fir::DebuggingResource>(effect.getResource()))
141       return mlir::AliasResult::NoAlias;
142     if (!accessedVal)
143       return mlir::AliasResult::MayAlias;
144     if (accessedVal == val)
145       return mlir::AliasResult::MustAlias;
146 
147     // if the accessed value might alias val
148     mlir::AliasResult res = aliasAnalysis.alias(val, accessedVal);
149     if (!res.isNo())
150       return res;
151 
152     // FIXME: alias analysis of fir.load
153     // follow this common pattern:
154     // %ref = hlfir.designate %array(%index)
155     // %val = fir.load $ref
156     if (auto designate = accessedVal.getDefiningOp<hlfir::DesignateOp>()) {
157       if (designate.getMemref() == val)
158         return mlir::AliasResult::MustAlias;
159 
160       // if the designate is into an array that might alias val
161       res = aliasAnalysis.alias(val, designate.getMemref());
162       if (!res.isNo())
163         return res;
164     }
165   }
166   return mlir::AliasResult::NoAlias;
167 }
168 
169 // Helper class for analyzing two array slices represented
170 // by two hlfir.designate operations.
// Helper class for analyzing two array slices represented
// by two hlfir.designate operations.
class ArraySectionAnalyzer {
public:
  // The result of the analysis is one of the values below.
  enum class SlicesOverlapKind {
    // Slices overlap is unknown.
    Unknown,
    // Slices are definitely identical.
    DefinitelyIdentical,
    // Slices are definitely disjoint.
    DefinitelyDisjoint,
    // Slices may be either disjoint or identical,
    // i.e. there is definitely no partial overlap.
    EitherIdenticalOrDisjoint
  };

  // Analyzes two hlfir.designate results and returns the overlap kind.
  // The callers may use this method when the alias analysis reports
  // an alias of some kind, so that we can run Fortran specific analysis
  // on the array slices to see if they are identical or disjoint.
  // Note that the alias analysis is not able to give such an answer
  // about the references.
  static SlicesOverlapKind analyze(mlir::Value ref1, mlir::Value ref2);

private:
  struct SectionDesc {
    // An array section is described by <lb, ub, stride> tuple.
    // If the designator's subscript is not a triple, then
    // the section descriptor is constructed as <lb, nullptr, nullptr>.
    mlir::Value lb, ub, stride;

    SectionDesc(mlir::Value lb, mlir::Value ub, mlir::Value stride)
        : lb(lb), ub(ub), stride(stride) {
      assert(lb && "lower bound or index must be specified");
      normalize();
    }

    // Normalize the section descriptor:
    //   1. If UB is nullptr, then it is set to LB.
    //   2. If LB==UB, then stride does not matter,
    //      so it is reset to nullptr.
    //   3. If STRIDE==1, then it is reset to nullptr.
    // Normalization makes equivalent sections structurally equal,
    // so that operator== can compare them member-wise.
    void normalize() {
      if (!ub)
        ub = lb;
      if (lb == ub)
        stride = nullptr;
      if (stride)
        if (auto val = fir::getIntIfConstant(stride))
          if (*val == 1)
            stride = nullptr;
    }

    // Member-wise SSA-value equality of the normalized descriptors.
    bool operator==(const SectionDesc &other) const {
      return lb == other.lb && ub == other.ub && stride == other.stride;
    }
  };

  // Given an operand_iterator over the indices operands,
  // read the subscript values and return them as SectionDesc
  // updating the iterator. If isTriplet is true,
  // the subscript is a triplet, and the result is <lb, ub, stride>.
  // Otherwise, the subscript is a scalar index, and the result
  // is <index, nullptr, nullptr>.
  static SectionDesc readSectionDesc(mlir::Operation::operand_iterator &it,
                                     bool isTriplet) {
    if (isTriplet)
      return {*it++, *it++, *it++};
    return {*it++, nullptr, nullptr};
  }

  // Return the ordered lower and upper bounds of the section.
  // If stride is known to be non-negative, then the ordered
  // bounds match the <lb, ub> of the descriptor.
  // If stride is known to be negative, then the ordered
  // bounds are <ub, lb> of the descriptor.
  // If stride is unknown, we cannot deduce any order,
  // so the result is <nullptr, nullptr>
  static std::pair<mlir::Value, mlir::Value>
  getOrderedBounds(const SectionDesc &desc) {
    mlir::Value stride = desc.stride;
    // Null stride means stride=1.
    if (!stride)
      return {desc.lb, desc.ub};
    // Reverse the bounds, if stride is negative.
    if (auto val = fir::getIntIfConstant(stride)) {
      if (*val >= 0)
        return {desc.lb, desc.ub};
      else
        return {desc.ub, desc.lb};
    }

    return {nullptr, nullptr};
  }

  // Given two array sections <lb1, ub1, stride1> and
  // <lb2, ub2, stride2>, return true only if the sections
  // are known to be disjoint.
  //
  // For example, for any positive constant C:
  //   X:Y does not overlap with (Y+C):Z
  //   X:Y does not overlap with Z:(X-C)
  static bool areDisjointSections(const SectionDesc &desc1,
                                  const SectionDesc &desc2) {
    auto [lb1, ub1] = getOrderedBounds(desc1);
    auto [lb2, ub2] = getOrderedBounds(desc2);
    // Unknown ordering (e.g. a variable stride) means we cannot prove
    // disjointness.
    if (!lb1 || !lb2)
      return false;
    // Note that this comparison must be made on the ordered bounds,
    // otherwise 'a(x:y:1) = a(z:x-1:-1) + 1' may be incorrectly treated
    // as not overlapping (x=2, y=10, z=9).
    if (isLess(ub1, lb2) || isLess(ub2, lb1))
      return true;
    return false;
  }

  // Given two array sections <lb1, ub1, stride1> and
  // <lb2, ub2, stride2>, return true only if the sections
  // are known to be identical.
  //
  // For example:
  //   <x, x, stride>
  //   <x, nullptr, nullptr>
  //
  // These sections are identical, from the point of which array
  // elements are being addressed, even though the shape
  // of the array slices might be different.
  static bool areIdenticalSections(const SectionDesc &desc1,
                                   const SectionDesc &desc2) {
    if (desc1 == desc2)
      return true;
    return false;
  }

  // Return true, if v1 is known to be less than v2.
  static bool isLess(mlir::Value v1, mlir::Value v2);
};
307 
// Compare two designator-based references of the same array and classify
// their overlap (identical / disjoint / identical-or-disjoint / unknown).
ArraySectionAnalyzer::SlicesOverlapKind
ArraySectionAnalyzer::analyze(mlir::Value ref1, mlir::Value ref2) {
  // The same SSA value trivially refers to the same elements.
  if (ref1 == ref2)
    return SlicesOverlapKind::DefinitelyIdentical;

  auto des1 = ref1.getDefiningOp<hlfir::DesignateOp>();
  auto des2 = ref2.getDefiningOp<hlfir::DesignateOp>();
  // We only support a pair of designators right now.
  if (!des1 || !des2)
    return SlicesOverlapKind::Unknown;

  if (des1.getMemref() != des2.getMemref()) {
    // If the bases are different, then there is unknown overlap.
    LLVM_DEBUG(llvm::dbgs() << "No identical base for:\n"
                            << des1 << "and:\n"
                            << des2 << "\n");
    return SlicesOverlapKind::Unknown;
  }

  // Require all components of the designators to be the same.
  // It might be too strict, e.g. we may probably allow for
  // different type parameters.
  if (des1.getComponent() != des2.getComponent() ||
      des1.getComponentShape() != des2.getComponentShape() ||
      des1.getSubstring() != des2.getSubstring() ||
      des1.getComplexPart() != des2.getComplexPart() ||
      des1.getTypeparams() != des2.getTypeparams()) {
    LLVM_DEBUG(llvm::dbgs() << "Different designator specs for:\n"
                            << des1 << "and:\n"
                            << des2 << "\n");
    return SlicesOverlapKind::Unknown;
  }

  // Analyze the subscripts dimension by dimension, tracking whether every
  // triplet subscript and every scalar subscript matched.
  auto des1It = des1.getIndices().begin();
  auto des2It = des2.getIndices().begin();
  bool identicalTriplets = true;
  bool identicalIndices = true;
  for (auto [isTriplet1, isTriplet2] :
       llvm::zip(des1.getIsTriplet(), des2.getIsTriplet())) {
    SectionDesc desc1 = readSectionDesc(des1It, isTriplet1);
    SectionDesc desc2 = readSectionDesc(des2It, isTriplet2);

    // See if we can prove that any of the sections do not overlap.
    // This is mostly a Polyhedron/nf performance hack that looks for
    // particular relations between the lower and upper bounds
    // of the array sections, e.g. for any positive constant C:
    //   X:Y does not overlap with (Y+C):Z
    //   X:Y does not overlap with Z:(X-C)
    if (areDisjointSections(desc1, desc2))
      return SlicesOverlapKind::DefinitelyDisjoint;

    if (!areIdenticalSections(desc1, desc2)) {
      if (isTriplet1 || isTriplet2) {
        // For example:
        //   hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %0)
        //   hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %1)
        //
        // If all the triplets (section specifiers) are the same, then
        // we do not care if %0 is equal to %1 - the slices are either
        // identical or completely disjoint.
        //
        // Also, treat these as identical sections:
        //   hlfir.designate %6#0 (%c2:%c2:%c1)
        //   hlfir.designate %6#0 (%c2)
        identicalTriplets = false;
        LLVM_DEBUG(llvm::dbgs() << "Triplet mismatch for:\n"
                                << des1 << "and:\n"
                                << des2 << "\n");
      } else {
        identicalIndices = false;
        LLVM_DEBUG(llvm::dbgs() << "Indices mismatch for:\n"
                                << des1 << "and:\n"
                                << des2 << "\n");
      }
    }
  }

  if (identicalTriplets) {
    if (identicalIndices)
      return SlicesOverlapKind::DefinitelyIdentical;
    else
      return SlicesOverlapKind::EitherIdenticalOrDisjoint;
  }

  LLVM_DEBUG(llvm::dbgs() << "Different sections for:\n"
                          << des1 << "and:\n"
                          << des2 << "\n");
  return SlicesOverlapKind::Unknown;
}
398 
399 bool ArraySectionAnalyzer::isLess(mlir::Value v1, mlir::Value v2) {
400   auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
401     auto *op = v.getDefiningOp();
402     while (auto conv = mlir::dyn_cast_or_null<fir::ConvertOp>(op))
403       op = conv.getValue().getDefiningOp();
404     return op;
405   };
406 
407   auto isPositiveConstant = [](mlir::Value v) -> bool {
408     if (auto val = fir::getIntIfConstant(v))
409       return *val > 0;
410     return false;
411   };
412 
413   auto *op1 = removeConvert(v1);
414   auto *op2 = removeConvert(v2);
415   if (!op1 || !op2)
416     return false;
417 
418   // Check if they are both constants.
419   if (auto val1 = fir::getIntIfConstant(op1->getResult(0)))
420     if (auto val2 = fir::getIntIfConstant(op2->getResult(0)))
421       return *val1 < *val2;
422 
423   // Handle some variable cases (C > 0):
424   //   v2 = v1 + C
425   //   v2 = C + v1
426   //   v1 = v2 - C
427   if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
428     if ((addi.getLhs().getDefiningOp() == op1 &&
429          isPositiveConstant(addi.getRhs())) ||
430         (addi.getRhs().getDefiningOp() == op1 &&
431          isPositiveConstant(addi.getLhs())))
432       return true;
433   if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
434     if (subi.getLhs().getDefiningOp() == op2 &&
435         isPositiveConstant(subi.getRhs()))
436       return true;
437   return false;
438 }
439 
llvm::SmallVector<mlir::Value>
ElementalAssignBufferization::getDesignatorIndices(
    hlfir::DesignateOp designate) {
  mlir::Value memref = designate.getMemref();

  // If the object is a box, then the indices may be adjusted
  // according to the box's lower bound(s). Scan through
  // the computations to try to find the one-based indices.
  if (mlir::isa<fir::BaseBoxType>(memref.getType())) {
    // Look for the following pattern:
    //   %13 = fir.load %12 : !fir.ref<!fir.box<...>
    //   %14:3 = fir.box_dims %13, %c0 : (!fir.box<...>, index) -> ...
    //   %17 = arith.subi %14#0, %c1 : index
    //   %18 = arith.addi %arg2, %17 : index
    //   %19 = hlfir.designate %13 (%18)  : (!fir.box<...>, index) -> ...
    //
    // %arg2 is a one-based index.

    auto isNormalizedLb = [memref](mlir::Value v, unsigned dim) {
      // Return true, if v and dim are such that:
      //   %14:3 = fir.box_dims %13, %dim : (!fir.box<...>, index) -> ...
      //   %17 = arith.subi %14#0, %c1 : index
      //   %19 = hlfir.designate %13 (...)  : (!fir.box<...>, index) -> ...
      if (auto subOp =
              mlir::dyn_cast_or_null<mlir::arith::SubIOp>(v.getDefiningOp())) {
        // The subtrahend must be the constant 1 (normalizing the lb).
        auto cst = fir::getIntIfConstant(subOp.getRhs());
        if (!cst || *cst != 1)
          return false;
        if (auto dimsOp = mlir::dyn_cast_or_null<fir::BoxDimsOp>(
                subOp.getLhs().getDefiningOp())) {
          // The box_dims must query this same box, and the value being
          // decremented must be its lower-bound result (result #0).
          if (memref != dimsOp.getVal() ||
              dimsOp.getResult(0) != subOp.getLhs())
            return false;
          // The queried dimension must be a constant matching this
          // subscript's dimension.
          auto dimsOpDim = fir::getIntIfConstant(dimsOp.getDim());
          return dimsOpDim && dimsOpDim == dim;
        }
      }
      return false;
    };

    llvm::SmallVector<mlir::Value> newIndices;
    for (auto index : llvm::enumerate(designate.getIndices())) {
      if (auto addOp = mlir::dyn_cast_or_null<mlir::arith::AddIOp>(
              index.value().getDefiningOp())) {
        // Try both operand orders: (one-based + (lb-1)) and ((lb-1) +
        // one-based); push the other operand as the one-based index.
        for (unsigned opNum = 0; opNum < 2; ++opNum)
          if (isNormalizedLb(addOp->getOperand(opNum), index.index())) {
            newIndices.push_back(addOp->getOperand((opNum + 1) % 2));
            break;
          }

        // If new one-based index was not added, exit early.
        if (newIndices.size() <= index.index())
          break;
      }
    }

    // If any of the indices is not adjusted to the array's lb,
    // then return the original designator indices.
    if (newIndices.size() != designate.getIndices().size())
      return designate.getIndices();

    return newIndices;
  }

  return designate.getIndices();
}
506 
507 std::optional<ElementalAssignBufferization::MatchInfo>
508 ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
509   mlir::Operation::user_range users = elemental->getUsers();
510   // the only uses of the elemental should be the assignment and the destroy
511   if (std::distance(users.begin(), users.end()) != 2) {
512     LLVM_DEBUG(llvm::dbgs() << "Too many uses of the elemental\n");
513     return std::nullopt;
514   }
515 
516   // If the ElementalOp must produce a temporary (e.g. for
517   // finalization purposes), then we cannot inline it.
518   if (hlfir::elementalOpMustProduceTemp(elemental)) {
519     LLVM_DEBUG(llvm::dbgs() << "ElementalOp must produce a temp\n");
520     return std::nullopt;
521   }
522 
523   MatchInfo match;
524   for (mlir::Operation *user : users)
525     mlir::TypeSwitch<mlir::Operation *, void>(user)
526         .Case([&](hlfir::AssignOp op) { match.assign = op; })
527         .Case([&](hlfir::DestroyOp op) { match.destroy = op; });
528 
529   if (!match.assign || !match.destroy) {
530     LLVM_DEBUG(llvm::dbgs() << "Couldn't find assign or destroy\n");
531     return std::nullopt;
532   }
533 
534   // the array is what the elemental is assigned into
535   // TODO: this could be extended to also allow hlfir.expr by first bufferizing
536   // the incoming expression
537   match.array = match.assign.getLhs();
538   mlir::Type arrayType = mlir::dyn_cast<fir::SequenceType>(
539       fir::unwrapPassByRefType(match.array.getType()));
540   if (!arrayType) {
541     LLVM_DEBUG(llvm::dbgs() << "AssignOp's result is not an array\n");
542     return std::nullopt;
543   }
544 
545   // require that the array elements are trivial
546   // TODO: this is just to make the pass easier to think about. Not an inherent
547   // limitation
548   mlir::Type eleTy = hlfir::getFortranElementType(arrayType);
549   if (!fir::isa_trivial(eleTy)) {
550     LLVM_DEBUG(llvm::dbgs() << "AssignOp's data type is not trivial\n");
551     return std::nullopt;
552   }
553 
554   // The array must have the same shape as the elemental.
555   //
556   // f2018 10.2.1.2 (3) requires the lhs and rhs of an assignment to be
557   // conformable unless the lhs is an allocatable array. In HLFIR we can
558   // see this from the presence or absence of the realloc attribute on
559   // hlfir.assign. If it is not a realloc assignment, we can trust that
560   // the shapes do conform.
561   //
562   // TODO: the lhs's shape is dynamic, so it is hard to prove that
563   // there is no reallocation of the lhs due to the assignment.
564   // We can probably try generating multiple versions of the code
565   // with checking for the shape match, length parameters match, etc.
566   if (match.assign.isAllocatableAssignment()) {
567     LLVM_DEBUG(llvm::dbgs() << "AssignOp may involve (re)allocation of LHS\n");
568     return std::nullopt;
569   }
570 
571   // the transformation wants to apply the elemental in a do-loop at the
572   // hlfir.assign, check there are no effects which make this unsafe
573 
574   // keep track of any values written to in the elemental, as these can't be
575   // read from between the elemental and the assignment
576   // likewise, values read in the elemental cannot be written to between the
577   // elemental and the assign
578   mlir::SmallVector<mlir::Value, 1> notToBeAccessedBeforeAssign;
579   // any accesses to the array between the array and the assignment means it
580   // would be unsafe to move the elemental to the assignment
581   notToBeAccessedBeforeAssign.push_back(match.array);
582 
583   // 1) side effects in the elemental body - it isn't sufficient to just look
584   // for ordered elementals because we also cannot support out of order reads
585   std::optional<mlir::SmallVector<mlir::MemoryEffects::EffectInstance>>
586       effects = getEffectsBetween(&elemental.getBody()->front(),
587                                   elemental.getBody()->getTerminator());
588   if (!effects) {
589     LLVM_DEBUG(llvm::dbgs()
590                << "operation with unknown effects inside elemental\n");
591     return std::nullopt;
592   }
593   for (const mlir::MemoryEffects::EffectInstance &effect : *effects) {
594     mlir::AliasResult res = containsReadOrWriteEffectOn(effect, match.array);
595     if (res.isNo()) {
596       if (mlir::isa<mlir::MemoryEffects::Write, mlir::MemoryEffects::Read>(
597               effect.getEffect()))
598         if (effect.getValue())
599           notToBeAccessedBeforeAssign.push_back(effect.getValue());
600 
601       // this is safe in the elemental
602       continue;
603     }
604 
605     // don't allow any aliasing writes in the elemental
606     if (mlir::isa<mlir::MemoryEffects::Write>(effect.getEffect())) {
607       LLVM_DEBUG(llvm::dbgs() << "write inside the elemental body\n");
608       return std::nullopt;
609     }
610 
611     // allow if and only if the reads are from the elemental indices, in order
612     // => each iteration doesn't read values written by other iterations
613     // don't allow reads from a different value which may alias: fir alias
614     // analysis isn't precise enough to tell us if two aliasing arrays overlap
615     // exactly or only partially. If they overlap partially, a designate at the
616     // elemental indices could be accessing different elements: e.g. we could
617     // designate two slices of the same array at different start indexes. These
618     // two MustAlias but index 1 of one array isn't the same element as index 1
619     // of the other array.
620     if (!res.isPartial()) {
621       if (auto designate =
622               effect.getValue().getDefiningOp<hlfir::DesignateOp>()) {
623         ArraySectionAnalyzer::SlicesOverlapKind overlap =
624             ArraySectionAnalyzer::analyze(match.array, designate.getMemref());
625         if (overlap ==
626             ArraySectionAnalyzer::SlicesOverlapKind::DefinitelyDisjoint)
627           continue;
628 
629         if (overlap == ArraySectionAnalyzer::SlicesOverlapKind::Unknown) {
630           LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate
631                                   << " at " << elemental.getLoc() << "\n");
632           return std::nullopt;
633         }
634         auto indices = getDesignatorIndices(designate);
635         auto elementalIndices = elemental.getIndices();
636         if (indices.size() == elementalIndices.size() &&
637             std::equal(indices.begin(), indices.end(), elementalIndices.begin(),
638                        elementalIndices.end()))
639           continue;
640 
641         LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate
642                                 << " at " << elemental.getLoc() << "\n");
643         return std::nullopt;
644       }
645     }
646     LLVM_DEBUG(llvm::dbgs() << "disallowed side-effect: " << effect.getValue()
647                             << " for " << elemental.getLoc() << "\n");
648     return std::nullopt;
649   }
650 
651   // 2) look for conflicting effects between the elemental and the assignment
652   effects = getEffectsBetween(elemental->getNextNode(), match.assign);
653   if (!effects) {
654     LLVM_DEBUG(
655         llvm::dbgs()
656         << "operation with unknown effects between elemental and assign\n");
657     return std::nullopt;
658   }
659   for (const mlir::MemoryEffects::EffectInstance &effect : *effects) {
660     // not safe to access anything written in the elemental as this write
661     // will be moved to the assignment
662     for (mlir::Value val : notToBeAccessedBeforeAssign) {
663       mlir::AliasResult res = containsReadOrWriteEffectOn(effect, val);
664       if (!res.isNo()) {
665         LLVM_DEBUG(llvm::dbgs()
666                    << "diasllowed side-effect: " << effect.getValue() << " for "
667                    << elemental.getLoc() << "\n");
668         return std::nullopt;
669       }
670     }
671   }
672 
673   return match;
674 }
675 
// Perform the rewrite: replace the elemental + assign + destroy triple with
// a loop nest that computes each element and assigns it directly into the
// destination array, avoiding the intermediate hlfir.expr temporary.
llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite(
    hlfir::ElementalOp elemental, mlir::PatternRewriter &rewriter) const {
  std::optional<MatchInfo> match = findMatch(elemental);
  if (!match)
    return rewriter.notifyMatchFailure(
        elemental, "cannot prove safety of ElementalAssignBufferization");

  mlir::Location loc = elemental->getLoc();
  fir::FirOpBuilder builder(rewriter, elemental.getOperation());
  auto extents = hlfir::getIndexExtents(loc, builder, elemental.getShape());

  // create the loop at the assignment
  builder.setInsertionPoint(match->assign);

  // Generate a loop nest looping around the hlfir.elemental shape and clone
  // hlfir.elemental region inside the inner loop. The loops are unordered
  // unless the elemental itself requires ordered evaluation.
  hlfir::LoopNest loopNest =
      hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered(),
                         flangomp::shouldUseWorkshareLowering(elemental));
  builder.setInsertionPointToStart(loopNest.body);
  auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
                                        loopNest.oneBasedIndices);
  hlfir::Entity elementValue{yield.getElementValue()};
  // the yield terminator is replaced by the element assignment below
  rewriter.eraseOp(yield);

  // Assign the element value to the array element for this iteration.
  auto arrayElement = hlfir::getElementAt(
      loc, builder, hlfir::Entity{match->array}, loopNest.oneBasedIndices);
  builder.create<hlfir::AssignOp>(
      loc, elementValue, arrayElement, /*realloc=*/false,
      /*keep_lhs_length_if_realloc=*/false, match->assign.getTemporaryLhs());

  // the original ops are now dead: their work happens inside the loop nest
  rewriter.eraseOp(match->assign);
  rewriter.eraseOp(match->destroy);
  rewriter.eraseOp(elemental);
  return mlir::success();
}
713 
714 /// Expand hlfir.assign of a scalar RHS to array LHS into a loop nest
715 /// of element-by-element assignments:
716 ///   hlfir.assign %cst to %0 : f32, !fir.ref<!fir.array<6x6xf32>>
717 /// into:
718 ///   fir.do_loop %arg0 = %c1 to %c6 step %c1 unordered {
719 ///     fir.do_loop %arg1 = %c1 to %c6 step %c1 unordered {
720 ///       %1 = hlfir.designate %0 (%arg1, %arg0)  :
721 ///       (!fir.ref<!fir.array<6x6xf32>>, index, index) -> !fir.ref<f32>
722 ///       hlfir.assign %cst to %1 : f32, !fir.ref<f32>
723 ///     }
724 ///   }
class BroadcastAssignBufferization
    : public mlir::OpRewritePattern<hlfir::AssignOp> {
private:
public:
  using mlir::OpRewritePattern<hlfir::AssignOp>::OpRewritePattern;

  /// Expand a trivial-scalar-to-array hlfir.assign into an explicit
  /// element-wise loop nest; fails to match otherwise.
  llvm::LogicalResult
  matchAndRewrite(hlfir::AssignOp assign,
                  mlir::PatternRewriter &rewriter) const override;
};
735 
736 llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
737     hlfir::AssignOp assign, mlir::PatternRewriter &rewriter) const {
738   // Since RHS is a scalar and LHS is an array, LHS must be allocated
739   // in a conforming Fortran program, and LHS cannot be reallocated
740   // as a result of the assignment. So we can ignore isAllocatableAssignment
741   // and do the transformation always.
742   mlir::Value rhs = assign.getRhs();
743   if (!fir::isa_trivial(rhs.getType()))
744     return rewriter.notifyMatchFailure(
745         assign, "AssignOp's RHS is not a trivial scalar");
746 
747   hlfir::Entity lhs{assign.getLhs()};
748   if (!lhs.isArray())
749     return rewriter.notifyMatchFailure(assign,
750                                        "AssignOp's LHS is not an array");
751 
752   mlir::Type eleTy = lhs.getFortranElementType();
753   if (!fir::isa_trivial(eleTy))
754     return rewriter.notifyMatchFailure(
755         assign, "AssignOp's LHS data type is not trivial");
756 
757   mlir::Location loc = assign->getLoc();
758   fir::FirOpBuilder builder(rewriter, assign.getOperation());
759   builder.setInsertionPoint(assign);
760   lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs);
761   mlir::Value shape = hlfir::genShape(loc, builder, lhs);
762   llvm::SmallVector<mlir::Value> extents =
763       hlfir::getIndexExtents(loc, builder, shape);
764   hlfir::LoopNest loopNest =
765       hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
766                          flangomp::shouldUseWorkshareLowering(assign));
767   builder.setInsertionPointToStart(loopNest.body);
768   auto arrayElement =
769       hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
770   builder.create<hlfir::AssignOp>(loc, rhs, arrayElement);
771   rewriter.eraseOp(assign);
772   return mlir::success();
773 }
774 
/// Callback generating one reduction step: receives the builder, location,
/// the current reduction value and the one-based loop indices, and returns
/// the updated reduction value.
using GenBodyFn =
    std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location, mlir::Value,
                              const llvm::SmallVectorImpl<mlir::Value> &)>;
/// Emit a fir.do_loop nest over \p shape that threads \p init through
/// \p genBody as a loop-carried value (iter_arg) and returns the final
/// reduced value.
static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder,
                                         mlir::Location loc, mlir::Value init,
                                         mlir::Value shape, GenBodyFn genBody) {
  auto extents = hlfir::getIndexExtents(loc, builder, shape);
  mlir::Value reduction = init;
  mlir::IndexType idxTy = builder.getIndexType();
  mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);

  // Create a reduction loop nest. We use one-based indices so that they can be
  // passed to the elemental, and reverse the order so that they can be
  // generated in column-major order for better performance.
  llvm::SmallVector<mlir::Value> indices(extents.size(), mlir::Value{});
  for (unsigned i = 0; i < extents.size(); ++i) {
    // The outermost loop runs over the last dimension; each loop carries the
    // reduction as its single iter_arg.
    auto loop = builder.create<fir::DoLoopOp>(
        loc, oneIdx, extents[extents.size() - i - 1], oneIdx, false,
        /*finalCountValue=*/false, reduction);
    reduction = loop.getRegionIterArgs()[0];
    indices[extents.size() - i - 1] = loop.getInductionVar();
    // Set insertion point to the loop body so that the next loop
    // is inserted inside the current one.
    builder.setInsertionPointToStart(loop.getBody());
  }

  // Generate the body
  reduction = genBody(builder, loc, reduction, indices);

  // Unwind the loop nest.
  for (unsigned i = 0; i < extents.size(); ++i) {
    // Yield the updated value to the enclosing loop and continue with that
    // loop's result one level further out.
    auto result = builder.create<fir::ResultOp>(loc, reduction);
    auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
    reduction = loop.getResult(0);
    // Set insertion point after the loop operation that we have
    // just processed.
    builder.setInsertionPointAfter(loop.getOperation());
  }

  return reduction;
}
816 
817 auto makeMinMaxInitValGenerator(bool isMax) {
818   return [isMax](fir::FirOpBuilder builder, mlir::Location loc,
819                  mlir::Type elementType) -> mlir::Value {
820     if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
821       const llvm::fltSemantics &sem = ty.getFloatSemantics();
822       llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
823       return builder.createRealConstant(loc, elementType, limit);
824     }
825     unsigned bits = elementType.getIntOrFloatBitWidth();
826     int64_t limitInt =
827         isMax ? llvm::APInt::getSignedMinValue(bits).getSExtValue()
828               : llvm::APInt::getSignedMaxValue(bits).getSExtValue();
829     return builder.createIntegerConstant(loc, elementType, limitInt);
830   };
831 }
832 
/// Emit the i1 condition deciding whether \p elem should replace the current
/// \p reduction value of a MIN/MAX reduction (\p isMax selects MAX
/// semantics). Only float and integer element types are supported.
mlir::Value generateMinMaxComparison(fir::FirOpBuilder builder,
                                     mlir::Location loc, mlir::Value elem,
                                     mlir::Value reduction, bool isMax) {
  if (mlir::isa<mlir::FloatType>(reduction.getType())) {
    // For FP reductions we want the first smallest value to be used, that
    // is not NaN. An OGT/OLT condition will usually work for this unless all
    // the values are NaN or Inf. This follows the same logic as
    // NumericCompare for Minloc/Maxloc in extrema.cpp.
    mlir::Value cmp = builder.create<mlir::arith::CmpFOp>(
        loc,
        isMax ? mlir::arith::CmpFPredicate::OGT
              : mlir::arith::CmpFPredicate::OLT,
        elem, reduction);
    // Also take elem if the running reduction is NaN while elem is a number,
    // so a leading NaN does not poison the whole reduction.
    mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
        loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
    mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
        loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
    cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
    return builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
  } else if (mlir::isa<mlir::IntegerType>(reduction.getType())) {
    // Integer reductions use plain signed comparison.
    return builder.create<mlir::arith::CmpIOp>(
        loc,
        isMax ? mlir::arith::CmpIPredicate::sgt
              : mlir::arith::CmpIPredicate::slt,
        elem, reduction);
  }
  llvm_unreachable("unsupported type");
}
861 
/// Given a reduction operation with an elemental/designate source, attempt to
/// generate a do-loop to perform the operation inline.
///   %e = hlfir.elemental %shape unordered
///   %r = hlfir.count %e
/// =>
///   %r = for.do_loop %arg = 1 to bound(%shape) step 1 iter_args(%arg2 = init)
///     %i = <inline elemental>
///     %c = <reduce count> %i
///     fir.result %c
template <typename Op>
class ReductionConversion : public mlir::OpRewritePattern<Op> {
public:
  using mlir::OpRewritePattern<Op>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    // Select source and validate its arguments.
    // Only total reductions are supported (no DIM), and MASK/BACK must be
    // absent where the op kind has them.
    mlir::Value source;
    bool valid = false;
    if constexpr (std::is_same_v<Op, hlfir::AnyOp> ||
                  std::is_same_v<Op, hlfir::AllOp> ||
                  std::is_same_v<Op, hlfir::CountOp>) {
      source = op.getMask();
      valid = !op.getDim();
    } else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> ||
                         std::is_same_v<Op, hlfir::MinvalOp>) {
      source = op.getArray();
      valid = !op.getDim() && !op.getMask();
    } else if constexpr (std::is_same_v<Op, hlfir::MaxlocOp> ||
                         std::is_same_v<Op, hlfir::MinlocOp>) {
      source = op.getArray();
      valid = !op.getDim() && !op.getMask() && !op.getBack();
    }
    if (!valid)
      return rewriter.notifyMatchFailure(
          op, "Currently does not accept optional arguments");

    // The source must be produced by an hlfir.elemental (inlined per
    // iteration below) or an hlfir.designate (addressed and loaded per
    // iteration below); the iteration shape is taken from it.
    hlfir::ElementalOp elemental;
    hlfir::DesignateOp designate;
    mlir::Value shape;
    if ((elemental = source.template getDefiningOp<hlfir::ElementalOp>())) {
      shape = elemental.getOperand(0);
    } else if ((designate =
                    source.template getDefiningOp<hlfir::DesignateOp>())) {
      shape = designate.getShape();
    } else {
      return rewriter.notifyMatchFailure(op, "Did not find valid argument");
    }

    // Produce the scalar source element for the given one-based indices.
    auto inlineSource =
        [elemental, &designate](
            fir::FirOpBuilder builder, mlir::Location loc,
            const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
      if (elemental) {
        // Inline the elemental and get the value from it.
        auto yield = inlineElementalOp(loc, builder, elemental, indices);
        auto tmp = yield.getElementValue();
        yield->erase();
        return tmp;
      }
      if (designate) {
        // Create a designator over designator, then load the reference.
        auto resEntity = hlfir::Entity{designate.getResult()};
        auto tmp = builder.create<hlfir::DesignateOp>(
            loc, getVariableElementType(resEntity), designate, indices);
        return builder.create<fir::LoadOp>(loc, tmp);
      }
      llvm_unreachable("unsupported type");
    };

    fir::KindMapping kindMap =
        fir::getKindMapping(op->template getParentOfType<mlir::ModuleOp>());
    fir::FirOpBuilder builder{op, kindMap};

    // Pick the reduction's initial value and per-element combiner for the
    // specific operation being rewritten.
    mlir::Value init;
    GenBodyFn genBodyFn;
    if constexpr (std::is_same_v<Op, hlfir::AnyOp>) {
      // ANY: start from false and OR each element in.
      init = builder.createIntegerConstant(loc, builder.getI1Type(), 0);
      genBodyFn =
          [inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
                         mlir::Value reduction,
                         const llvm::SmallVectorImpl<mlir::Value> &indices)
          -> mlir::Value {
        // Conditionally set the reduction variable.
        mlir::Value cond = builder.create<fir::ConvertOp>(
            loc, builder.getI1Type(), inlineSource(builder, loc, indices));
        return builder.create<mlir::arith::OrIOp>(loc, reduction, cond);
      };
    } else if constexpr (std::is_same_v<Op, hlfir::AllOp>) {
      // ALL: start from true and AND each element in.
      init = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
      genBodyFn =
          [inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
                         mlir::Value reduction,
                         const llvm::SmallVectorImpl<mlir::Value> &indices)
          -> mlir::Value {
        // Conditionally set the reduction variable.
        mlir::Value cond = builder.create<fir::ConvertOp>(
            loc, builder.getI1Type(), inlineSource(builder, loc, indices));
        return builder.create<mlir::arith::AndIOp>(loc, reduction, cond);
      };
    } else if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
      // COUNT: start from 0 and increment when the mask element is true.
      init = builder.createIntegerConstant(loc, op.getType(), 0);
      genBodyFn =
          [inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
                         mlir::Value reduction,
                         const llvm::SmallVectorImpl<mlir::Value> &indices)
          -> mlir::Value {
        // Conditionally add one to the current value
        mlir::Value cond = builder.create<fir::ConvertOp>(
            loc, builder.getI1Type(), inlineSource(builder, loc, indices));
        mlir::Value one =
            builder.createIntegerConstant(loc, reduction.getType(), 1);
        mlir::Value add1 =
            builder.create<mlir::arith::AddIOp>(loc, reduction, one);
        return builder.create<mlir::arith::SelectOp>(loc, cond, add1,
                                                     reduction);
      };
    } else if constexpr (std::is_same_v<Op, hlfir::MaxlocOp> ||
                         std::is_same_v<Op, hlfir::MinlocOp>) {
      // TODO: implement minloc/maxloc conversion.
      return rewriter.notifyMatchFailure(
          op, "Currently minloc/maxloc is not handled");
    } else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> ||
                         std::is_same_v<Op, hlfir::MinvalOp>) {
      // MAXVAL/MINVAL: start from the neutral extreme and keep the better of
      // the running value and the current element.
      bool isMax = std::is_same_v<Op, hlfir::MaxvalOp>;
      init = makeMinMaxInitValGenerator(isMax)(builder, loc, op.getType());
      genBodyFn = [inlineSource,
                   isMax](fir::FirOpBuilder builder, mlir::Location loc,
                          mlir::Value reduction,
                          const llvm::SmallVectorImpl<mlir::Value> &indices)
          -> mlir::Value {
        mlir::Value val = inlineSource(builder, loc, indices);
        mlir::Value cmp =
            generateMinMaxComparison(builder, loc, val, reduction, isMax);
        return builder.create<mlir::arith::SelectOp>(loc, cmp, val, reduction);
      };
    } else {
      llvm_unreachable("unsupported type");
    }

    mlir::Value res =
        generateReductionLoop(builder, loc, init, shape, genBodyFn);
    if (res.getType() != op.getType())
      res = builder.create<fir::ConvertOp>(loc, op.getType(), res);

    // Check if the op was the only user of the source (apart from a destroy),
    // and remove it if so.
    mlir::Operation *sourceOp = source.getDefiningOp();
    mlir::Operation::user_range srcUsers = sourceOp->getUsers();
    hlfir::DestroyOp srcDestroy;
    if (std::distance(srcUsers.begin(), srcUsers.end()) == 2) {
      // With exactly two users (this op plus possibly a destroy, in either
      // order), look for the destroy among them.
      srcDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*srcUsers.begin());
      if (!srcDestroy)
        srcDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++srcUsers.begin());
    }

    rewriter.replaceOp(op, res);
    if (srcDestroy) {
      rewriter.eraseOp(srcDestroy);
      rewriter.eraseOp(sourceOp);
    }
    return mlir::success();
  }
};
1027 
// Look for minloc(mask=elemental) and generate the minloc loop with
// inlined elemental.
//  %e = hlfir.elemental %shape ({ ... })
//  %m = hlfir.minloc %array mask %e
template <typename Op>
class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
public:
  using mlir::OpRewritePattern<Op>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(Op mloc, mlir::PatternRewriter &rewriter) const override {
    // Only masked, total (no DIM), forward-search (no BACK) forms are
    // rewritten here.
    if (!mloc.getMask() || mloc.getDim() || mloc.getBack())
      return rewriter.notifyMatchFailure(mloc,
                                         "Did not find valid minloc/maxloc");

    bool isMax = std::is_same_v<Op, hlfir::MaxlocOp>;

    auto elemental =
        mloc.getMask().template getDefiningOp<hlfir::ElementalOp>();
    if (!elemental || hlfir::elementalOpMustProduceTemp(elemental))
      return rewriter.notifyMatchFailure(mloc, "Did not find elemental");

    mlir::Value array = mloc.getArray();

    // The result of minloc/maxloc is a rank-1 array whose extent is the rank
    // of the searched array.
    unsigned rank = mlir::cast<hlfir::ExprType>(mloc.getType()).getShape()[0];
    mlir::Type arrayType = array.getType();
    if (!mlir::isa<fir::BoxType>(arrayType))
      return rewriter.notifyMatchFailure(
          mloc, "Currently requires a boxed type input");
    mlir::Type elementType = hlfir::getFortranElementType(arrayType);
    if (!fir::isa_trivial(elementType))
      return rewriter.notifyMatchFailure(
          mloc, "Character arrays are currently not handled");

    mlir::Location loc = mloc.getLoc();
    fir::FirOpBuilder builder{rewriter, mloc.getOperation()};
    // Stack temporary holding the location result (one index per dimension).
    mlir::Value resultArr = builder.createTemporary(
        loc, fir::SequenceType::get(
                 rank, hlfir::getFortranElementType(mloc.getType())));

    auto init = makeMinMaxInitValGenerator(isMax);

    // Body of the innermost reduction iteration; receives zero-based indices
    // and the flag tracking whether the mask has been true yet.
    auto genBodyOp =
        [&rank, &resultArr, &elemental, isMax](
            fir::FirOpBuilder builder, mlir::Location loc,
            mlir::Type elementType, mlir::Value array, mlir::Value flagRef,
            mlir::Value reduction,
            const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
      // We are in the innermost loop: generate the elemental inline
      mlir::Value oneIdx =
          builder.createIntegerConstant(loc, builder.getIndexType(), 1);
      llvm::SmallVector<mlir::Value> oneBasedIndices;
      llvm::transform(
          indices, std::back_inserter(oneBasedIndices), [&](mlir::Value V) {
            return builder.create<mlir::arith::AddIOp>(loc, V, oneIdx);
          });
      hlfir::YieldElementOp yield =
          hlfir::inlineElementalOp(loc, builder, elemental, oneBasedIndices);
      mlir::Value maskElem = yield.getElementValue();
      yield->erase();

      mlir::Type ifCompatType = builder.getI1Type();
      mlir::Value ifCompatElem =
          builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem);

      llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType};
      // Everything below only runs when the mask element is true; the else
      // branch forwards the unchanged reduction value.
      fir::IfOp maskIfOp =
          builder.create<fir::IfOp>(loc, elementType, ifCompatElem,
                                    /*withElseRegion=*/true);
      builder.setInsertionPointToStart(&maskIfOp.getThenRegion().front());

      // Set flag that mask was true at some point
      mlir::Value flagSet = builder.createIntegerConstant(
          loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
      mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef);
      mlir::Value addr = hlfir::getElementAt(loc, builder, hlfir::Entity{array},
                                             oneBasedIndices);
      mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);

      // Compare with the max reduction value
      mlir::Value cmp =
          generateMinMaxComparison(builder, loc, elem, reduction, isMax);

      // The condition used for the loop is isFirst || <the condition above>.
      // isFirst is negated (XOR with 1) because the flag is 0 until the mask
      // has first been true.
      isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst);
      isFirst = builder.create<mlir::arith::XOrIOp>(
          loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1));
      cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst);

      // Set the new coordinate to the result
      fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
                                                 /*withElseRegion*/ true);

      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
      builder.create<fir::StoreOp>(loc, flagSet, flagRef);
      mlir::Type resultElemTy =
          hlfir::getFortranElementType(resultArr.getType());
      mlir::Type returnRefTy = builder.getRefType(resultElemTy);
      mlir::IndexType idxTy = builder.getIndexType();

      // Record the current (one-based) coordinates as the new best location.
      for (unsigned int i = 0; i < rank; ++i) {
        mlir::Value index = builder.createIntegerConstant(loc, idxTy, i + 1);
        mlir::Value resultElemAddr = builder.create<hlfir::DesignateOp>(
            loc, returnRefTy, resultArr, index);
        mlir::Value fortranIndex = builder.create<fir::ConvertOp>(
            loc, resultElemTy, oneBasedIndices[i]);
        builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr);
      }
      builder.create<fir::ResultOp>(loc, elem);
      builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
      builder.create<fir::ResultOp>(loc, reduction);
      builder.setInsertionPointAfter(ifOp);

      // Close the mask if
      builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
      builder.setInsertionPointToStart(&maskIfOp.getElseRegion().front());
      builder.create<fir::ResultOp>(loc, reduction);
      builder.setInsertionPointAfter(maskIfOp);

      return maskIfOp.getResult(0);
    };
    // Address one element of the rank-1 result array (index is zero-based on
    // entry, converted to one-based for the designator).
    auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
                        const mlir::Type &resultElemType, mlir::Value resultArr,
                        mlir::Value index) {
      mlir::Type resultRefTy = builder.getRefType(resultElemType);
      mlir::Value oneIdx =
          builder.createIntegerConstant(loc, builder.getIndexType(), 1);
      index = builder.create<mlir::arith::AddIOp>(loc, index, oneIdx);
      return builder.create<hlfir::DesignateOp>(loc, resultRefTy, resultArr,
                                                index);
    };

    // Initialize the result
    // (all-zero is the Fortran result when the mask is never true).
    mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
    mlir::Type resultRefTy = builder.getRefType(resultElemTy);
    mlir::Value returnValue =
        builder.createIntegerConstant(loc, resultElemTy, 0);
    for (unsigned int i = 0; i < rank; ++i) {
      mlir::Value index =
          builder.createIntegerConstant(loc, builder.getIndexType(), i + 1);
      mlir::Value resultElemAddr = builder.create<hlfir::DesignateOp>(
          loc, resultRefTy, resultArr, index);
      builder.create<fir::StoreOp>(loc, returnValue, resultElemAddr);
    }

    fir::genMinMaxlocReductionLoop(builder, array, init, genBodyOp, getAddrFn,
                                   rank, elementType, loc, builder.getI1Type(),
                                   resultArr, false);

    mlir::Value asExpr = builder.create<hlfir::AsExprOp>(
        loc, resultArr, builder.createBool(loc, false));

    // Check all the users - the destroy is no longer required, and any assign
    // can use resultArr directly so that InlineHLFIRAssign pass
    // can optimize the results. Other operations are replaced with an AsExpr
    // for the temporary resultArr.
    llvm::SmallVector<hlfir::DestroyOp> destroys;
    llvm::SmallVector<hlfir::AssignOp> assigns;
    for (auto user : mloc->getUsers()) {
      if (auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(user))
        destroys.push_back(destroy);
      else if (auto assign = mlir::dyn_cast<hlfir::AssignOp>(user))
        assigns.push_back(assign);
    }

    // Check if the minloc/maxloc was the only user of the elemental (apart from
    // a destroy), and remove it if so.
    mlir::Operation::user_range elemUsers = elemental->getUsers();
    hlfir::DestroyOp elemDestroy;
    if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
      elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin());
      if (!elemDestroy)
        elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin());
    }

    for (auto d : destroys)
      rewriter.eraseOp(d);
    for (auto a : assigns)
      a.setOperand(0, resultArr);
    rewriter.replaceOp(mloc, asExpr);
    if (elemDestroy) {
      rewriter.eraseOp(elemDestroy);
      rewriter.eraseOp(elemental);
    }
    return mlir::success();
  }
};
1215 
1216 class EvaluateIntoMemoryAssignBufferization
1217     : public mlir::OpRewritePattern<hlfir::EvaluateInMemoryOp> {
1218 
1219 public:
1220   using mlir::OpRewritePattern<hlfir::EvaluateInMemoryOp>::OpRewritePattern;
1221 
1222   llvm::LogicalResult
1223   matchAndRewrite(hlfir::EvaluateInMemoryOp,
1224                   mlir::PatternRewriter &rewriter) const override;
1225 };
1226 
1227 static llvm::LogicalResult
1228 tryUsingAssignLhsDirectly(hlfir::EvaluateInMemoryOp evalInMem,
1229                           mlir::PatternRewriter &rewriter) {
1230   mlir::Location loc = evalInMem.getLoc();
1231   hlfir::DestroyOp destroy;
1232   hlfir::AssignOp assign;
1233   for (auto user : llvm::enumerate(evalInMem->getUsers())) {
1234     if (user.index() > 2)
1235       return mlir::failure();
1236     mlir::TypeSwitch<mlir::Operation *, void>(user.value())
1237         .Case([&](hlfir::AssignOp op) { assign = op; })
1238         .Case([&](hlfir::DestroyOp op) { destroy = op; });
1239   }
1240   if (!assign || !destroy || destroy.mustFinalizeExpr() ||
1241       assign.isAllocatableAssignment())
1242     return mlir::failure();
1243 
1244   hlfir::Entity lhs{assign.getLhs()};
1245   // EvaluateInMemoryOp memory is contiguous, so in general, it can only be
1246   // replace by the LHS if the LHS is contiguous.
1247   if (!lhs.isSimplyContiguous())
1248     return mlir::failure();
1249   // Character assignment may involves truncation/padding, so the LHS
1250   // cannot be used to evaluate RHS in place without proving the LHS and
1251   // RHS lengths are the same.
1252   if (lhs.isCharacter())
1253     return mlir::failure();
1254   fir::AliasAnalysis aliasAnalysis;
1255   // The region must not read or write the LHS.
1256   // Note that getModRef is used instead of mlir::MemoryEffects because
1257   // EvaluateInMemoryOp is typically expected to hold fir.calls and that
1258   // Fortran calls cannot be modeled in a useful way with mlir::MemoryEffects:
1259   // it is hard/impossible to list all the read/written SSA values in a call,
1260   // but it is often possible to tell that an SSA value cannot be accessed,
1261   // hence getModRef is needed here and below. Also note that getModRef uses
1262   // mlir::MemoryEffects for operations that do not have special handling in
1263   // getModRef.
1264   if (aliasAnalysis.getModRef(evalInMem.getBody(), lhs).isModOrRef())
1265     return mlir::failure();
1266   // Any variables affected between the hlfir.evalInMem and assignment must not
1267   // be read or written inside the region since it will be moved at the
1268   // assignment insertion point.
1269   auto effects = getEffectsBetween(evalInMem->getNextNode(), assign);
1270   if (!effects) {
1271     LLVM_DEBUG(
1272         llvm::dbgs()
1273         << "operation with unknown effects between eval_in_mem and assign\n");
1274     return mlir::failure();
1275   }
1276   for (const mlir::MemoryEffects::EffectInstance &effect : *effects) {
1277     mlir::Value affected = effect.getValue();
1278     if (!affected ||
1279         aliasAnalysis.getModRef(evalInMem.getBody(), affected).isModOrRef())
1280       return mlir::failure();
1281   }
1282 
1283   rewriter.setInsertionPoint(assign);
1284   fir::FirOpBuilder builder(rewriter, evalInMem.getOperation());
1285   mlir::Value rawLhs = hlfir::genVariableRawAddress(loc, builder, lhs);
1286   hlfir::computeEvaluateOpIn(loc, builder, evalInMem, rawLhs);
1287   rewriter.eraseOp(assign);
1288   rewriter.eraseOp(destroy);
1289   rewriter.eraseOp(evalInMem);
1290   return mlir::success();
1291 }
1292 
1293 llvm::LogicalResult EvaluateIntoMemoryAssignBufferization::matchAndRewrite(
1294     hlfir::EvaluateInMemoryOp evalInMem,
1295     mlir::PatternRewriter &rewriter) const {
1296   if (mlir::succeeded(tryUsingAssignLhsDirectly(evalInMem, rewriter)))
1297     return mlir::success();
1298   // Rewrite to temp + as_expr here so that the assign + as_expr pattern can
1299   // kick-in for simple types and at least implement the assignment inline
1300   // instead of call Assign runtime.
1301   fir::FirOpBuilder builder(rewriter, evalInMem.getOperation());
1302   mlir::Location loc = evalInMem.getLoc();
1303   auto [temp, isHeapAllocated] = hlfir::computeEvaluateOpInNewTemp(
1304       loc, builder, evalInMem, evalInMem.getShape(), evalInMem.getTypeparams());
1305   rewriter.replaceOpWithNewOp<hlfir::AsExprOp>(
1306       evalInMem, temp, /*mustFree=*/builder.createBool(loc, isHeapAllocated));
1307   return mlir::success();
1308 }
1309 
1310 class OptimizedBufferizationPass
1311     : public hlfir::impl::OptimizedBufferizationBase<
1312           OptimizedBufferizationPass> {
1313 public:
1314   void runOnOperation() override {
1315     mlir::MLIRContext *context = &getContext();
1316 
1317     mlir::GreedyRewriteConfig config;
1318     // Prevent the pattern driver from merging blocks
1319     config.enableRegionSimplification =
1320         mlir::GreedySimplifyRegionLevel::Disabled;
1321 
1322     mlir::RewritePatternSet patterns(context);
1323     // TODO: right now the patterns are non-conflicting,
1324     // but it might be better to run this pass on hlfir.assign
1325     // operations and decide which transformation to apply
1326     // at one place (e.g. we may use some heuristics and
1327     // choose different optimization strategies).
1328     // This requires small code reordering in ElementalAssignBufferization.
1329     patterns.insert<ElementalAssignBufferization>(context);
1330     patterns.insert<BroadcastAssignBufferization>(context);
1331     patterns.insert<EvaluateIntoMemoryAssignBufferization>(context);
1332     patterns.insert<ReductionConversion<hlfir::CountOp>>(context);
1333     patterns.insert<ReductionConversion<hlfir::AnyOp>>(context);
1334     patterns.insert<ReductionConversion<hlfir::AllOp>>(context);
1335     // TODO: implement basic minloc/maxloc conversion.
1336     // patterns.insert<ReductionConversion<hlfir::MaxlocOp>>(context);
1337     // patterns.insert<ReductionConversion<hlfir::MinlocOp>>(context);
1338     patterns.insert<ReductionConversion<hlfir::MaxvalOp>>(context);
1339     patterns.insert<ReductionConversion<hlfir::MinvalOp>>(context);
1340     patterns.insert<ReductionMaskConversion<hlfir::MinlocOp>>(context);
1341     patterns.insert<ReductionMaskConversion<hlfir::MaxlocOp>>(context);
1342     // TODO: implement masked minval/maxval conversion.
1343     // patterns.insert<ReductionMaskConversion<hlfir::MaxvalOp>>(context);
1344     // patterns.insert<ReductionMaskConversion<hlfir::MinvalOp>>(context);
1345 
1346     if (mlir::failed(mlir::applyPatternsGreedily(
1347             getOperation(), std::move(patterns), config))) {
1348       mlir::emitError(getOperation()->getLoc(),
1349                       "failure in HLFIR optimized bufferization");
1350       signalPassFailure();
1351     }
1352   }
1353 };
1354 } // namespace
1355