xref: /llvm-project/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp (revision c4891089125d4ba312204cc9a666339abbfc4db2)
1 //===- LowerHLFIRIntrinsics.cpp - Transformational intrinsics to FIR ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <functional>
#include <optional>
27 
28 namespace hlfir {
29 #define GEN_PASS_DEF_LOWERHLFIRINTRINSICS
30 #include "flang/Optimizer/HLFIR/Passes.h.inc"
31 } // namespace hlfir
32 
33 namespace {
34 
35 /// Base class for passes converting transformational intrinsic operations into
36 /// runtime calls
37 template <class OP>
38 class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
39 public:
40   explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx)
41       : mlir::OpRewritePattern<OP>{ctx} {
42     // required for cases where intrinsics are chained together e.g.
43     // matmul(matmul(a, b), c)
44     // because converting the inner operation then invalidates the
45     // outer operation: causing the pattern to apply recursively.
46     //
47     // This is safe because we always progress with each iteration. Circular
48     // applications of operations are not expressible in MLIR because we use
49     // an SSA form and one must become first. E.g.
50     // %a = hlfir.matmul %b %d
51     // %b = hlfir.matmul %a %d
52     // cannot be written.
53     // MSVC needs the this->
54     this->setHasBoundedRewriteRecursion(true);
55   }
56 
57 protected:
58   struct IntrinsicArgument {
59     mlir::Value val; // allowed to be null if the argument is absent
60     mlir::Type desiredType;
61   };
62 
63   /// Lower the arguments to the intrinsic: adding necessary boxing and
64   /// conversion to match the signature of the intrinsic in the runtime library.
65   llvm::SmallVector<fir::ExtendedValue, 3>
66   lowerArguments(mlir::Operation *op,
67                  const llvm::ArrayRef<IntrinsicArgument> &args,
68                  mlir::PatternRewriter &rewriter,
69                  const fir::IntrinsicArgumentLoweringRules *argLowering) const {
70     mlir::Location loc = op->getLoc();
71     fir::FirOpBuilder builder{rewriter, op};
72 
73     llvm::SmallVector<fir::ExtendedValue, 3> ret;
74     llvm::SmallVector<std::function<void()>, 2> cleanupFns;
75 
76     for (size_t i = 0; i < args.size(); ++i) {
77       mlir::Value arg = args[i].val;
78       mlir::Type desiredType = args[i].desiredType;
79       if (!arg) {
80         ret.emplace_back(fir::getAbsentIntrinsicArgument());
81         continue;
82       }
83       hlfir::Entity entity{arg};
84 
85       fir::ArgLoweringRule argRules =
86           fir::lowerIntrinsicArgumentAs(*argLowering, i);
87       switch (argRules.lowerAs) {
88       case fir::LowerIntrinsicArgAs::Value: {
89         if (args[i].desiredType != arg.getType()) {
90           arg = builder.createConvert(loc, desiredType, arg);
91           entity = hlfir::Entity{arg};
92         }
93         auto [exv, cleanup] = hlfir::convertToValue(loc, builder, entity);
94         if (cleanup)
95           cleanupFns.push_back(*cleanup);
96         ret.emplace_back(exv);
97       } break;
98       case fir::LowerIntrinsicArgAs::Addr: {
99         auto [exv, cleanup] =
100             hlfir::convertToAddress(loc, builder, entity, desiredType);
101         if (cleanup)
102           cleanupFns.push_back(*cleanup);
103         ret.emplace_back(exv);
104       } break;
105       case fir::LowerIntrinsicArgAs::Box: {
106         auto [box, cleanup] =
107             hlfir::convertToBox(loc, builder, entity, desiredType);
108         if (cleanup)
109           cleanupFns.push_back(*cleanup);
110         ret.emplace_back(box);
111       } break;
112       case fir::LowerIntrinsicArgAs::Inquired: {
113         if (args[i].desiredType != arg.getType()) {
114           arg = builder.createConvert(loc, desiredType, arg);
115           entity = hlfir::Entity{arg};
116         }
117         // Place hlfir.expr in memory, and unbox fir.boxchar. Other entities
118         // are translated to fir::ExtendedValue without transofrmation (notably,
119         // pointers/allocatable are not dereferenced).
120         // TODO: once lowering to FIR retires, UBOUND and LBOUND can be
121         // simplified since the fir.box lowered here are now guarenteed to
122         // contain the local lower bounds thanks to the hlfir.declare (the extra
123         // rebox can be removed).
124         auto [exv, cleanup] =
125             hlfir::translateToExtendedValue(loc, builder, entity);
126         if (cleanup)
127           cleanupFns.push_back(*cleanup);
128         ret.emplace_back(exv);
129       } break;
130       }
131     }
132 
133     if (cleanupFns.size()) {
134       auto oldInsertionPoint = builder.saveInsertionPoint();
135       builder.setInsertionPointAfter(op);
136       for (std::function<void()> cleanup : cleanupFns)
137         cleanup();
138       builder.restoreInsertionPoint(oldInsertionPoint);
139     }
140 
141     return ret;
142   }
143 
144   void processReturnValue(mlir::Operation *op,
145                           const fir::ExtendedValue &resultExv, bool mustBeFreed,
146                           fir::FirOpBuilder &builder,
147                           mlir::PatternRewriter &rewriter) const {
148     mlir::Location loc = op->getLoc();
149 
150     mlir::Value firBase = fir::getBase(resultExv);
151     mlir::Type firBaseTy = firBase.getType();
152 
153     std::optional<hlfir::EntityWithAttributes> resultEntity;
154     if (fir::isa_trivial(firBaseTy)) {
155       // Some intrinsics return i1 when the original operation
156       // produces fir.logical<>, so we may need to cast it.
157       firBase = builder.createConvert(loc, op->getResult(0).getType(), firBase);
158       resultEntity = hlfir::EntityWithAttributes{firBase};
159     } else {
160       resultEntity =
161           hlfir::genDeclare(loc, builder, resultExv, ".tmp.intrinsic_result",
162                             fir::FortranVariableFlagsAttr{});
163     }
164 
165     if (resultEntity->isVariable()) {
166       hlfir::AsExprOp asExpr = builder.create<hlfir::AsExprOp>(
167           loc, *resultEntity, builder.createBool(loc, mustBeFreed));
168       resultEntity = hlfir::EntityWithAttributes{asExpr.getResult()};
169     }
170 
171     mlir::Value base = resultEntity->getBase();
172     if (!mlir::isa<hlfir::ExprType>(base.getType())) {
173       for (mlir::Operation *use : op->getResult(0).getUsers()) {
174         if (mlir::isa<hlfir::DestroyOp>(use))
175           rewriter.eraseOp(use);
176       }
177     }
178 
179     rewriter.replaceOp(op, base);
180   }
181 };
182 
183 // Given an integer or array of integer type, calculate the Kind parameter from
184 // the width for use in runtime intrinsic calls.
185 static unsigned getKindForType(mlir::Type ty) {
186   mlir::Type eltty = hlfir::getFortranElementType(ty);
187   unsigned width = mlir::cast<mlir::IntegerType>(eltty).getWidth();
188   return width / 8;
189 }
190 
191 template <class OP>
192 class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
193   using HlfirIntrinsicConversion<OP>::HlfirIntrinsicConversion;
194   using IntrinsicArgument =
195       typename HlfirIntrinsicConversion<OP>::IntrinsicArgument;
196   using HlfirIntrinsicConversion<OP>::lowerArguments;
197   using HlfirIntrinsicConversion<OP>::processReturnValue;
198 
199 protected:
200   auto buildNumericalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
201                           mlir::PatternRewriter &rewriter,
202                           std::string opName) const {
203     llvm::SmallVector<IntrinsicArgument, 3> inArgs;
204     inArgs.push_back({operation.getArray(), operation.getArray().getType()});
205     inArgs.push_back({operation.getDim(), i32});
206     inArgs.push_back({operation.getMask(), logicalType});
207     auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
208     return lowerArguments(operation, inArgs, rewriter, argLowering);
209   };
210 
211   auto buildMinMaxLocArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
212                           mlir::PatternRewriter &rewriter, std::string opName,
213                           fir::FirOpBuilder builder) const {
214     llvm::SmallVector<IntrinsicArgument, 3> inArgs;
215     inArgs.push_back({operation.getArray(), operation.getArray().getType()});
216     inArgs.push_back({operation.getDim(), i32});
217     inArgs.push_back({operation.getMask(), logicalType});
218     mlir::Value kind = builder.createIntegerConstant(
219         operation->getLoc(), i32, getKindForType(operation.getType()));
220     inArgs.push_back({kind, i32});
221     inArgs.push_back({operation.getBack(), i32});
222     auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
223     return lowerArguments(operation, inArgs, rewriter, argLowering);
224   };
225 
226   auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
227                         mlir::PatternRewriter &rewriter,
228                         std::string opName) const {
229     llvm::SmallVector<IntrinsicArgument, 2> inArgs;
230     inArgs.push_back({operation.getMask(), logicalType});
231     inArgs.push_back({operation.getDim(), i32});
232     auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
233     return lowerArguments(operation, inArgs, rewriter, argLowering);
234   };
235 
236 public:
237   llvm::LogicalResult
238   matchAndRewrite(OP operation,
239                   mlir::PatternRewriter &rewriter) const override {
240     std::string opName;
241     if constexpr (std::is_same_v<OP, hlfir::SumOp>) {
242       opName = "sum";
243     } else if constexpr (std::is_same_v<OP, hlfir::ProductOp>) {
244       opName = "product";
245     } else if constexpr (std::is_same_v<OP, hlfir::MaxvalOp>) {
246       opName = "maxval";
247     } else if constexpr (std::is_same_v<OP, hlfir::MinvalOp>) {
248       opName = "minval";
249     } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp>) {
250       opName = "minloc";
251     } else if constexpr (std::is_same_v<OP, hlfir::MaxlocOp>) {
252       opName = "maxloc";
253     } else if constexpr (std::is_same_v<OP, hlfir::AnyOp>) {
254       opName = "any";
255     } else if constexpr (std::is_same_v<OP, hlfir::AllOp>) {
256       opName = "all";
257     } else {
258       return mlir::failure();
259     }
260 
261     fir::FirOpBuilder builder{rewriter, operation.getOperation()};
262     const mlir::Location &loc = operation->getLoc();
263 
264     mlir::Type i32 = builder.getI32Type();
265     mlir::Type logicalType = fir::LogicalType::get(
266         builder.getContext(), builder.getKindMap().defaultLogicalKind());
267 
268     llvm::SmallVector<fir::ExtendedValue, 0> args;
269 
270     if constexpr (std::is_same_v<OP, hlfir::SumOp> ||
271                   std::is_same_v<OP, hlfir::ProductOp> ||
272                   std::is_same_v<OP, hlfir::MaxvalOp> ||
273                   std::is_same_v<OP, hlfir::MinvalOp>) {
274       args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName);
275     } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp> ||
276                          std::is_same_v<OP, hlfir::MaxlocOp>) {
277       args = buildMinMaxLocArgs(operation, i32, logicalType, rewriter, opName,
278                                 builder);
279     } else {
280       args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName);
281     }
282 
283     mlir::Type scalarResultType =
284         hlfir::getFortranElementType(operation.getType());
285 
286     auto [resultExv, mustBeFreed] =
287         fir::genIntrinsicCall(builder, loc, opName, scalarResultType, args);
288 
289     processReturnValue(operation, resultExv, mustBeFreed, builder, rewriter);
290     return mlir::success();
291   }
292 };
293 
294 using SumOpConversion = HlfirReductionIntrinsicConversion<hlfir::SumOp>;
295 
296 using ProductOpConversion = HlfirReductionIntrinsicConversion<hlfir::ProductOp>;
297 
298 using MaxvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxvalOp>;
299 
300 using MinvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinvalOp>;
301 
302 using MinlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinlocOp>;
303 
304 using MaxlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxlocOp>;
305 
306 using AnyOpConversion = HlfirReductionIntrinsicConversion<hlfir::AnyOp>;
307 
308 using AllOpConversion = HlfirReductionIntrinsicConversion<hlfir::AllOp>;
309 
310 struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> {
311   using HlfirIntrinsicConversion<hlfir::CountOp>::HlfirIntrinsicConversion;
312 
313   llvm::LogicalResult
314   matchAndRewrite(hlfir::CountOp count,
315                   mlir::PatternRewriter &rewriter) const override {
316     fir::FirOpBuilder builder{rewriter, count.getOperation()};
317     const mlir::Location &loc = count->getLoc();
318 
319     mlir::Type i32 = builder.getI32Type();
320     mlir::Type logicalType = fir::LogicalType::get(
321         builder.getContext(), builder.getKindMap().defaultLogicalKind());
322 
323     llvm::SmallVector<IntrinsicArgument, 3> inArgs;
324     inArgs.push_back({count.getMask(), logicalType});
325     inArgs.push_back({count.getDim(), i32});
326     mlir::Value kind = builder.createIntegerConstant(
327         count->getLoc(), i32, getKindForType(count.getType()));
328     inArgs.push_back({kind, i32});
329 
330     auto *argLowering = fir::getIntrinsicArgumentLowering("count");
331     llvm::SmallVector<fir::ExtendedValue, 3> args =
332         lowerArguments(count, inArgs, rewriter, argLowering);
333 
334     mlir::Type scalarResultType = hlfir::getFortranElementType(count.getType());
335 
336     auto [resultExv, mustBeFreed] =
337         fir::genIntrinsicCall(builder, loc, "count", scalarResultType, args);
338 
339     processReturnValue(count, resultExv, mustBeFreed, builder, rewriter);
340     return mlir::success();
341   }
342 };
343 
344 struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> {
345   using HlfirIntrinsicConversion<hlfir::MatmulOp>::HlfirIntrinsicConversion;
346 
347   llvm::LogicalResult
348   matchAndRewrite(hlfir::MatmulOp matmul,
349                   mlir::PatternRewriter &rewriter) const override {
350     fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
351     const mlir::Location &loc = matmul->getLoc();
352 
353     mlir::Value lhs = matmul.getLhs();
354     mlir::Value rhs = matmul.getRhs();
355     llvm::SmallVector<IntrinsicArgument, 2> inArgs;
356     inArgs.push_back({lhs, lhs.getType()});
357     inArgs.push_back({rhs, rhs.getType()});
358 
359     auto *argLowering = fir::getIntrinsicArgumentLowering("matmul");
360     llvm::SmallVector<fir::ExtendedValue, 2> args =
361         lowerArguments(matmul, inArgs, rewriter, argLowering);
362 
363     mlir::Type scalarResultType =
364         hlfir::getFortranElementType(matmul.getType());
365 
366     auto [resultExv, mustBeFreed] =
367         fir::genIntrinsicCall(builder, loc, "matmul", scalarResultType, args);
368 
369     processReturnValue(matmul, resultExv, mustBeFreed, builder, rewriter);
370     return mlir::success();
371   }
372 };
373 
374 struct DotProductOpConversion
375     : public HlfirIntrinsicConversion<hlfir::DotProductOp> {
376   using HlfirIntrinsicConversion<hlfir::DotProductOp>::HlfirIntrinsicConversion;
377 
378   llvm::LogicalResult
379   matchAndRewrite(hlfir::DotProductOp dotProduct,
380                   mlir::PatternRewriter &rewriter) const override {
381     fir::FirOpBuilder builder{rewriter, dotProduct.getOperation()};
382     const mlir::Location &loc = dotProduct->getLoc();
383 
384     mlir::Value lhs = dotProduct.getLhs();
385     mlir::Value rhs = dotProduct.getRhs();
386     llvm::SmallVector<IntrinsicArgument, 2> inArgs;
387     inArgs.push_back({lhs, lhs.getType()});
388     inArgs.push_back({rhs, rhs.getType()});
389 
390     auto *argLowering = fir::getIntrinsicArgumentLowering("dot_product");
391     llvm::SmallVector<fir::ExtendedValue, 2> args =
392         lowerArguments(dotProduct, inArgs, rewriter, argLowering);
393 
394     mlir::Type scalarResultType =
395         hlfir::getFortranElementType(dotProduct.getType());
396 
397     auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
398         builder, loc, "dot_product", scalarResultType, args);
399 
400     processReturnValue(dotProduct, resultExv, mustBeFreed, builder, rewriter);
401     return mlir::success();
402   }
403 };
404 
405 class TransposeOpConversion
406     : public HlfirIntrinsicConversion<hlfir::TransposeOp> {
407   using HlfirIntrinsicConversion<hlfir::TransposeOp>::HlfirIntrinsicConversion;
408 
409   llvm::LogicalResult
410   matchAndRewrite(hlfir::TransposeOp transpose,
411                   mlir::PatternRewriter &rewriter) const override {
412     fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
413     const mlir::Location &loc = transpose->getLoc();
414 
415     mlir::Value arg = transpose.getArray();
416     llvm::SmallVector<IntrinsicArgument, 1> inArgs;
417     inArgs.push_back({arg, arg.getType()});
418 
419     auto *argLowering = fir::getIntrinsicArgumentLowering("transpose");
420     llvm::SmallVector<fir::ExtendedValue, 1> args =
421         lowerArguments(transpose, inArgs, rewriter, argLowering);
422 
423     mlir::Type scalarResultType =
424         hlfir::getFortranElementType(transpose.getType());
425 
426     auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
427         builder, loc, "transpose", scalarResultType, args);
428 
429     processReturnValue(transpose, resultExv, mustBeFreed, builder, rewriter);
430     return mlir::success();
431   }
432 };
433 
434 struct MatmulTransposeOpConversion
435     : public HlfirIntrinsicConversion<hlfir::MatmulTransposeOp> {
436   using HlfirIntrinsicConversion<
437       hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion;
438 
439   llvm::LogicalResult
440   matchAndRewrite(hlfir::MatmulTransposeOp multranspose,
441                   mlir::PatternRewriter &rewriter) const override {
442     fir::FirOpBuilder builder{rewriter, multranspose.getOperation()};
443     const mlir::Location &loc = multranspose->getLoc();
444 
445     mlir::Value lhs = multranspose.getLhs();
446     mlir::Value rhs = multranspose.getRhs();
447     llvm::SmallVector<IntrinsicArgument, 2> inArgs;
448     inArgs.push_back({lhs, lhs.getType()});
449     inArgs.push_back({rhs, rhs.getType()});
450 
451     auto *argLowering = fir::getIntrinsicArgumentLowering("matmul");
452     llvm::SmallVector<fir::ExtendedValue, 2> args =
453         lowerArguments(multranspose, inArgs, rewriter, argLowering);
454 
455     mlir::Type scalarResultType =
456         hlfir::getFortranElementType(multranspose.getType());
457 
458     auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
459         builder, loc, "matmul_transpose", scalarResultType, args);
460 
461     processReturnValue(multranspose, resultExv, mustBeFreed, builder, rewriter);
462     return mlir::success();
463   }
464 };
465 
466 class CShiftOpConversion : public HlfirIntrinsicConversion<hlfir::CShiftOp> {
467   using HlfirIntrinsicConversion<hlfir::CShiftOp>::HlfirIntrinsicConversion;
468 
469   llvm::LogicalResult
470   matchAndRewrite(hlfir::CShiftOp cshift,
471                   mlir::PatternRewriter &rewriter) const override {
472     fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
473     const mlir::Location &loc = cshift->getLoc();
474 
475     llvm::SmallVector<IntrinsicArgument, 3> inArgs;
476     mlir::Value array = cshift.getArray();
477     inArgs.push_back({array, array.getType()});
478     mlir::Value shift = cshift.getShift();
479     inArgs.push_back({shift, shift.getType()});
480     inArgs.push_back({cshift.getDim(), builder.getI32Type()});
481 
482     auto *argLowering = fir::getIntrinsicArgumentLowering("cshift");
483     llvm::SmallVector<fir::ExtendedValue, 3> args =
484         lowerArguments(cshift, inArgs, rewriter, argLowering);
485 
486     mlir::Type scalarResultType =
487         hlfir::getFortranElementType(cshift.getType());
488 
489     auto [resultExv, mustBeFreed] =
490         fir::genIntrinsicCall(builder, loc, "cshift", scalarResultType, args);
491 
492     processReturnValue(cshift, resultExv, mustBeFreed, builder, rewriter);
493     return mlir::success();
494   }
495 };
496 
497 class ReshapeOpConversion : public HlfirIntrinsicConversion<hlfir::ReshapeOp> {
498   using HlfirIntrinsicConversion<hlfir::ReshapeOp>::HlfirIntrinsicConversion;
499 
500   llvm::LogicalResult
501   matchAndRewrite(hlfir::ReshapeOp reshape,
502                   mlir::PatternRewriter &rewriter) const override {
503     fir::FirOpBuilder builder{rewriter, reshape.getOperation()};
504     const mlir::Location &loc = reshape->getLoc();
505 
506     llvm::SmallVector<IntrinsicArgument, 4> inArgs;
507     mlir::Value array = reshape.getArray();
508     inArgs.push_back({array, array.getType()});
509     mlir::Value shape = reshape.getShape();
510     inArgs.push_back({shape, shape.getType()});
511     mlir::Type noneType = builder.getNoneType();
512     mlir::Value pad = reshape.getPad();
513     inArgs.push_back({pad, pad ? pad.getType() : noneType});
514     mlir::Value order = reshape.getOrder();
515     inArgs.push_back({order, order ? order.getType() : noneType});
516 
517     auto *argLowering = fir::getIntrinsicArgumentLowering("reshape");
518     llvm::SmallVector<fir::ExtendedValue, 4> args =
519         lowerArguments(reshape, inArgs, rewriter, argLowering);
520 
521     mlir::Type scalarResultType =
522         hlfir::getFortranElementType(reshape.getType());
523 
524     auto [resultExv, mustBeFreed] =
525         fir::genIntrinsicCall(builder, loc, "reshape", scalarResultType, args);
526 
527     processReturnValue(reshape, resultExv, mustBeFreed, builder, rewriter);
528     return mlir::success();
529   }
530 };
531 
532 class LowerHLFIRIntrinsics
533     : public hlfir::impl::LowerHLFIRIntrinsicsBase<LowerHLFIRIntrinsics> {
534 public:
535   void runOnOperation() override {
536     mlir::ModuleOp module = this->getOperation();
537     mlir::MLIRContext *context = &getContext();
538     mlir::RewritePatternSet patterns(context);
539     patterns.insert<
540         MatmulOpConversion, MatmulTransposeOpConversion, AllOpConversion,
541         AnyOpConversion, SumOpConversion, ProductOpConversion,
542         TransposeOpConversion, CountOpConversion, DotProductOpConversion,
543         MaxvalOpConversion, MinvalOpConversion, MinlocOpConversion,
544         MaxlocOpConversion, CShiftOpConversion, ReshapeOpConversion>(context);
545 
546     // While conceptually this pass is performing dialect conversion, we use
547     // pattern rewrites here instead of dialect conversion because this pass
548     // looses array bounds from some of the expressions e.g.
549     // !hlfir.expr<2xi32> -> !hlfir.expr<?xi32>
550     // MLIR thinks this is a different type so dialect conversion fails.
551     // Pattern rewriting only requires that the resulting IR is still valid
552     mlir::GreedyRewriteConfig config;
553     // Prevent the pattern driver from merging blocks
554     config.enableRegionSimplification =
555         mlir::GreedySimplifyRegionLevel::Disabled;
556 
557     if (mlir::failed(
558             mlir::applyPatternsGreedily(module, std::move(patterns), config))) {
559       mlir::emitError(mlir::UnknownLoc::get(context),
560                       "failure in HLFIR intrinsic lowering");
561       signalPassFailure();
562     }
563   }
564 };
565 } // namespace
566