xref: /llvm-project/flang/lib/Lower/HlfirIntrinsics.cpp (revision c4891089125d4ba312204cc9a666339abbfc4db2)
1 //===-- HlfirIntrinsics.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Lower/HlfirIntrinsics.h"
14 
15 #include "flang/Optimizer/Builder/BoxValue.h"
16 #include "flang/Optimizer/Builder/FIRBuilder.h"
17 #include "flang/Optimizer/Builder/HLFIRTools.h"
18 #include "flang/Optimizer/Builder/IntrinsicCall.h"
19 #include "flang/Optimizer/Builder/MutableBox.h"
20 #include "flang/Optimizer/Builder/Todo.h"
21 #include "flang/Optimizer/Dialect/FIRType.h"
22 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
23 #include "flang/Optimizer/HLFIR/HLFIROps.h"
24 #include "mlir/IR/Value.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include <mlir/IR/ValueRange.h>
27 
28 namespace {
29 
30 class HlfirTransformationalIntrinsic {
31 public:
32   explicit HlfirTransformationalIntrinsic(fir::FirOpBuilder &builder,
33                                           mlir::Location loc)
34       : builder(builder), loc(loc) {}
35 
36   virtual ~HlfirTransformationalIntrinsic() = default;
37 
38   hlfir::EntityWithAttributes
39   lower(const Fortran::lower::PreparedActualArguments &loweredActuals,
40         const fir::IntrinsicArgumentLoweringRules *argLowering,
41         mlir::Type stmtResultType) {
42     mlir::Value res = lowerImpl(loweredActuals, argLowering, stmtResultType);
43     for (const hlfir::CleanupFunction &fn : cleanupFns)
44       fn();
45     return {hlfir::EntityWithAttributes{res}};
46   }
47 
48 protected:
49   fir::FirOpBuilder &builder;
50   mlir::Location loc;
51   llvm::SmallVector<hlfir::CleanupFunction, 3> cleanupFns;
52 
53   virtual mlir::Value
54   lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
55             const fir::IntrinsicArgumentLoweringRules *argLowering,
56             mlir::Type stmtResultType) = 0;
57 
58   llvm::SmallVector<mlir::Value> getOperandVector(
59       const Fortran::lower::PreparedActualArguments &loweredActuals,
60       const fir::IntrinsicArgumentLoweringRules *argLowering);
61 
62   mlir::Type computeResultType(mlir::Value argArray, mlir::Type stmtResultType);
63 
64   template <typename OP, typename... BUILD_ARGS>
65   inline OP createOp(BUILD_ARGS... args) {
66     return builder.create<OP>(loc, args...);
67   }
68 
69   mlir::Value loadBoxAddress(
70       const std::optional<Fortran::lower::PreparedActualArgument> &arg);
71 
72   void addCleanup(std::optional<hlfir::CleanupFunction> cleanup) {
73     if (cleanup)
74       cleanupFns.emplace_back(std::move(*cleanup));
75   }
76 };
77 
78 template <typename OP, bool HAS_MASK>
79 class HlfirReductionIntrinsic : public HlfirTransformationalIntrinsic {
80 public:
81   using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
82 
83 protected:
84   mlir::Value
85   lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
86             const fir::IntrinsicArgumentLoweringRules *argLowering,
87             mlir::Type stmtResultType) override;
88 };
89 using HlfirSumLowering = HlfirReductionIntrinsic<hlfir::SumOp, true>;
90 using HlfirProductLowering = HlfirReductionIntrinsic<hlfir::ProductOp, true>;
91 using HlfirMaxvalLowering = HlfirReductionIntrinsic<hlfir::MaxvalOp, true>;
92 using HlfirMinvalLowering = HlfirReductionIntrinsic<hlfir::MinvalOp, true>;
93 using HlfirAnyLowering = HlfirReductionIntrinsic<hlfir::AnyOp, false>;
94 using HlfirAllLowering = HlfirReductionIntrinsic<hlfir::AllOp, false>;
95 
96 template <typename OP>
97 class HlfirMinMaxLocIntrinsic : public HlfirTransformationalIntrinsic {
98 public:
99   using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
100 
101 protected:
102   mlir::Value
103   lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
104             const fir::IntrinsicArgumentLoweringRules *argLowering,
105             mlir::Type stmtResultType) override;
106 };
107 using HlfirMinlocLowering = HlfirMinMaxLocIntrinsic<hlfir::MinlocOp>;
108 using HlfirMaxlocLowering = HlfirMinMaxLocIntrinsic<hlfir::MaxlocOp>;
109 
110 template <typename OP>
111 class HlfirProductIntrinsic : public HlfirTransformationalIntrinsic {
112 public:
113   using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
114 
115 protected:
116   mlir::Value
117   lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
118             const fir::IntrinsicArgumentLoweringRules *argLowering,
119             mlir::Type stmtResultType) override;
120 };
121 using HlfirMatmulLowering = HlfirProductIntrinsic<hlfir::MatmulOp>;
122 using HlfirDotProductLowering = HlfirProductIntrinsic<hlfir::DotProductOp>;
123 
124 class HlfirTransposeLowering : public HlfirTransformationalIntrinsic {
125 public:
126   using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
127 
128 protected:
129   mlir::Value
130   lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
131             const fir::IntrinsicArgumentLoweringRules *argLowering,
132             mlir::Type stmtResultType) override;
133 };
134 
135 class HlfirCountLowering : public HlfirTransformationalIntrinsic {
136 public:
137   using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
138 
139 protected:
140   mlir::Value
141   lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
142             const fir::IntrinsicArgumentLoweringRules *argLowering,
143             mlir::Type stmtResultType) override;
144 };
145 
146 class HlfirCharExtremumLowering : public HlfirTransformationalIntrinsic {
147 public:
148   HlfirCharExtremumLowering(fir::FirOpBuilder &builder, mlir::Location loc,
149                             hlfir::CharExtremumPredicate pred)
150       : HlfirTransformationalIntrinsic(builder, loc), pred{pred} {}
151 
152 protected:
153   mlir::Value
154   lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
155             const fir::IntrinsicArgumentLoweringRules *argLowering,
156             mlir::Type stmtResultType) override;
157 
158 protected:
159   hlfir::CharExtremumPredicate pred;
160 };
161 
162 class HlfirCShiftLowering : public HlfirTransformationalIntrinsic {
163 public:
164   using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
165 
166 protected:
167   mlir::Value
168   lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
169             const fir::IntrinsicArgumentLoweringRules *argLowering,
170             mlir::Type stmtResultType) override;
171 };
172 
173 class HlfirReshapeLowering : public HlfirTransformationalIntrinsic {
174 public:
175   using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
176 
177 protected:
178   mlir::Value
179   lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
180             const fir::IntrinsicArgumentLoweringRules *argLowering,
181             mlir::Type stmtResultType) override;
182 };
183 
184 } // namespace
185 
186 mlir::Value HlfirTransformationalIntrinsic::loadBoxAddress(
187     const std::optional<Fortran::lower::PreparedActualArgument> &arg) {
188   if (!arg)
189     return mlir::Value{};
190 
191   hlfir::Entity actual = arg->getActual(loc, builder);
192 
193   if (!arg->handleDynamicOptional()) {
194     if (actual.isMutableBox()) {
195       // this is a box address type but is not dynamically optional. Just load
196       // the box, assuming it is well formed (!fir.ref<!fir.box<...>> ->
197       // !fir.box<...>)
198       return builder.create<fir::LoadOp>(loc, actual.getBase());
199     }
200     return actual;
201   }
202 
203   auto [exv, cleanup] = hlfir::translateToExtendedValue(loc, builder, actual);
204   addCleanup(cleanup);
205 
206   mlir::Value isPresent = arg->getIsPresent();
207   // createBox will not do create any invalid memory dereferences if exv is
208   // absent. The created fir.box will not be usable, but the SelectOp below
209   // ensures it won't be.
210   mlir::Value box = builder.createBox(loc, exv);
211   mlir::Type boxType = box.getType();
212   auto absent = builder.create<fir::AbsentOp>(loc, boxType);
213   auto boxOrAbsent = builder.create<mlir::arith::SelectOp>(
214       loc, boxType, isPresent, box, absent);
215 
216   return boxOrAbsent;
217 }
218 
219 static mlir::Value loadOptionalValue(
220     mlir::Location loc, fir::FirOpBuilder &builder,
221     const std::optional<Fortran::lower::PreparedActualArgument> &arg,
222     hlfir::Entity actual) {
223   if (!arg->handleDynamicOptional())
224     return hlfir::loadTrivialScalar(loc, builder, actual);
225 
226   mlir::Value isPresent = arg->getIsPresent();
227   mlir::Type eleType = hlfir::getFortranElementType(actual.getType());
228   return builder
229       .genIfOp(loc, {eleType}, isPresent,
230                /*withElseRegion=*/true)
231       .genThen([&]() {
232         assert(actual.isScalar() && fir::isa_trivial(eleType) &&
233                "must be a numerical or logical scalar");
234         hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, actual);
235         builder.create<fir::ResultOp>(loc, val);
236       })
237       .genElse([&]() {
238         mlir::Value zero = fir::factory::createZeroValue(builder, loc, eleType);
239         builder.create<fir::ResultOp>(loc, zero);
240       })
241       .getResults()[0];
242 }
243 
244 llvm::SmallVector<mlir::Value> HlfirTransformationalIntrinsic::getOperandVector(
245     const Fortran::lower::PreparedActualArguments &loweredActuals,
246     const fir::IntrinsicArgumentLoweringRules *argLowering) {
247   llvm::SmallVector<mlir::Value> operands;
248   operands.reserve(loweredActuals.size());
249 
250   for (size_t i = 0; i < loweredActuals.size(); ++i) {
251     std::optional<Fortran::lower::PreparedActualArgument> arg =
252         loweredActuals[i];
253     if (!arg) {
254       operands.emplace_back();
255       continue;
256     }
257     hlfir::Entity actual = arg->getActual(loc, builder);
258     mlir::Value valArg;
259 
260     if (!argLowering) {
261       valArg = hlfir::loadTrivialScalar(loc, builder, actual);
262     } else {
263       fir::ArgLoweringRule argRules =
264           fir::lowerIntrinsicArgumentAs(*argLowering, i);
265       if (argRules.lowerAs == fir::LowerIntrinsicArgAs::Box)
266         valArg = loadBoxAddress(arg);
267       else if (!argRules.handleDynamicOptional &&
268                argRules.lowerAs != fir::LowerIntrinsicArgAs::Inquired)
269         valArg = hlfir::derefPointersAndAllocatables(loc, builder, actual);
270       else if (argRules.handleDynamicOptional &&
271                argRules.lowerAs == fir::LowerIntrinsicArgAs::Value)
272         valArg = loadOptionalValue(loc, builder, arg, actual);
273       else if (argRules.handleDynamicOptional)
274         TODO(loc, "hlfir transformational intrinsic dynamically optional "
275                   "argument without box lowering");
276       else
277         valArg = actual.getBase();
278     }
279 
280     operands.emplace_back(valArg);
281   }
282   return operands;
283 }
284 
285 mlir::Type
286 HlfirTransformationalIntrinsic::computeResultType(mlir::Value argArray,
287                                                   mlir::Type stmtResultType) {
288   mlir::Type normalisedResult =
289       hlfir::getFortranElementOrSequenceType(stmtResultType);
290   if (auto array = mlir::dyn_cast<fir::SequenceType>(normalisedResult)) {
291     hlfir::ExprType::Shape resultShape =
292         hlfir::ExprType::Shape{array.getShape()};
293     mlir::Type elementType = array.getEleTy();
294     return hlfir::ExprType::get(builder.getContext(), resultShape, elementType,
295                                 fir::isPolymorphicType(stmtResultType));
296   } else if (auto resCharType =
297                  mlir::dyn_cast<fir::CharacterType>(stmtResultType)) {
298     normalisedResult = hlfir::ExprType::get(
299         builder.getContext(), hlfir::ExprType::Shape{}, resCharType,
300         /*polymorphic=*/false);
301   }
302   return normalisedResult;
303 }
304 
305 template <typename OP, bool HAS_MASK>
306 mlir::Value HlfirReductionIntrinsic<OP, HAS_MASK>::lowerImpl(
307     const Fortran::lower::PreparedActualArguments &loweredActuals,
308     const fir::IntrinsicArgumentLoweringRules *argLowering,
309     mlir::Type stmtResultType) {
310   auto operands = getOperandVector(loweredActuals, argLowering);
311   mlir::Value array = operands[0];
312   mlir::Value dim = operands[1];
313   // dim, mask can be NULL if these arguments are not given
314   if (dim)
315     dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
316 
317   mlir::Type resultTy = computeResultType(array, stmtResultType);
318 
319   OP op;
320   if constexpr (HAS_MASK)
321     op = createOp<OP>(resultTy, array, dim,
322                       /*mask=*/operands[2]);
323   else
324     op = createOp<OP>(resultTy, array, dim);
325   return op;
326 }
327 
328 template <typename OP>
329 mlir::Value HlfirMinMaxLocIntrinsic<OP>::lowerImpl(
330     const Fortran::lower::PreparedActualArguments &loweredActuals,
331     const fir::IntrinsicArgumentLoweringRules *argLowering,
332     mlir::Type stmtResultType) {
333   auto operands = getOperandVector(loweredActuals, argLowering);
334   mlir::Value array = operands[0];
335   mlir::Value dim = operands[1];
336   mlir::Value mask = operands[2];
337   mlir::Value back = operands[4];
338   // dim, mask and back can be NULL if these arguments are not given.
339   if (dim)
340     dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
341   if (back)
342     back = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{back});
343 
344   mlir::Type resultTy = computeResultType(array, stmtResultType);
345 
346   return createOp<OP>(resultTy, array, dim, mask, back);
347 }
348 
349 template <typename OP>
350 mlir::Value HlfirProductIntrinsic<OP>::lowerImpl(
351     const Fortran::lower::PreparedActualArguments &loweredActuals,
352     const fir::IntrinsicArgumentLoweringRules *argLowering,
353     mlir::Type stmtResultType) {
354   auto operands = getOperandVector(loweredActuals, argLowering);
355   mlir::Type resultType = computeResultType(operands[0], stmtResultType);
356   return createOp<OP>(resultType, operands[0], operands[1]);
357 }
358 
359 mlir::Value HlfirTransposeLowering::lowerImpl(
360     const Fortran::lower::PreparedActualArguments &loweredActuals,
361     const fir::IntrinsicArgumentLoweringRules *argLowering,
362     mlir::Type stmtResultType) {
363   auto operands = getOperandVector(loweredActuals, argLowering);
364   hlfir::ExprType::Shape resultShape;
365   mlir::Type normalisedResult =
366       hlfir::getFortranElementOrSequenceType(stmtResultType);
367   auto array = mlir::cast<fir::SequenceType>(normalisedResult);
368   llvm::ArrayRef<int64_t> arrayShape = array.getShape();
369   assert(arrayShape.size() == 2 && "arguments to transpose have a rank of 2");
370   mlir::Type elementType = array.getEleTy();
371   resultShape.push_back(arrayShape[0]);
372   resultShape.push_back(arrayShape[1]);
373   if (auto resCharType = mlir::dyn_cast<fir::CharacterType>(elementType))
374     if (!resCharType.hasConstantLen()) {
375       // The FunctionRef expression might have imprecise character
376       // type at this point, and we can improve it by propagating
377       // the constant length from the argument.
378       auto argCharType = mlir::dyn_cast<fir::CharacterType>(
379           hlfir::getFortranElementType(operands[0].getType()));
380       if (argCharType && argCharType.hasConstantLen())
381         elementType = fir::CharacterType::get(
382             builder.getContext(), resCharType.getFKind(), argCharType.getLen());
383     }
384 
385   mlir::Type resultTy =
386       hlfir::ExprType::get(builder.getContext(), resultShape, elementType,
387                            fir::isPolymorphicType(stmtResultType));
388   return createOp<hlfir::TransposeOp>(resultTy, operands[0]);
389 }
390 
391 mlir::Value HlfirCountLowering::lowerImpl(
392     const Fortran::lower::PreparedActualArguments &loweredActuals,
393     const fir::IntrinsicArgumentLoweringRules *argLowering,
394     mlir::Type stmtResultType) {
395   auto operands = getOperandVector(loweredActuals, argLowering);
396   mlir::Value array = operands[0];
397   mlir::Value dim = operands[1];
398   if (dim)
399     dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
400   mlir::Type resultType = computeResultType(array, stmtResultType);
401   return createOp<hlfir::CountOp>(resultType, array, dim);
402 }
403 
404 mlir::Value HlfirCharExtremumLowering::lowerImpl(
405     const Fortran::lower::PreparedActualArguments &loweredActuals,
406     const fir::IntrinsicArgumentLoweringRules *argLowering,
407     mlir::Type stmtResultType) {
408   auto operands = getOperandVector(loweredActuals, argLowering);
409   assert(operands.size() >= 2);
410   return createOp<hlfir::CharExtremumOp>(pred, mlir::ValueRange{operands});
411 }
412 
413 mlir::Value HlfirCShiftLowering::lowerImpl(
414     const Fortran::lower::PreparedActualArguments &loweredActuals,
415     const fir::IntrinsicArgumentLoweringRules *argLowering,
416     mlir::Type stmtResultType) {
417   auto operands = getOperandVector(loweredActuals, argLowering);
418   assert(operands.size() == 3);
419   mlir::Value dim = operands[2];
420   if (!dim) {
421     // If DIM is not present, drop the last element which is a null Value.
422     operands.truncate(2);
423   } else {
424     // If DIM is present, then dereference it if it is a ref.
425     dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
426     operands[2] = dim;
427   }
428 
429   mlir::Type resultType = computeResultType(operands[0], stmtResultType);
430   return createOp<hlfir::CShiftOp>(resultType, operands);
431 }
432 
433 mlir::Value HlfirReshapeLowering::lowerImpl(
434     const Fortran::lower::PreparedActualArguments &loweredActuals,
435     const fir::IntrinsicArgumentLoweringRules *argLowering,
436     mlir::Type stmtResultType) {
437   auto operands = getOperandVector(loweredActuals, argLowering);
438   assert(operands.size() == 4);
439   mlir::Type resultType = computeResultType(operands[0], stmtResultType);
440   return createOp<hlfir::ReshapeOp>(resultType, operands[0], operands[1],
441                                     operands[2], operands[3]);
442 }
443 
444 std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic(
445     fir::FirOpBuilder &builder, mlir::Location loc, const std::string &name,
446     const Fortran::lower::PreparedActualArguments &loweredActuals,
447     const fir::IntrinsicArgumentLoweringRules *argLowering,
448     mlir::Type stmtResultType) {
449   // If the result is of a derived type that may need finalization,
450   // we have to use DestroyOp with 'finalize' attribute for the result
451   // of the intrinsic operation.
452   if (name == "sum")
453     return HlfirSumLowering{builder, loc}.lower(loweredActuals, argLowering,
454                                                 stmtResultType);
455   if (name == "product")
456     return HlfirProductLowering{builder, loc}.lower(loweredActuals, argLowering,
457                                                     stmtResultType);
458   if (name == "any")
459     return HlfirAnyLowering{builder, loc}.lower(loweredActuals, argLowering,
460                                                 stmtResultType);
461   if (name == "all")
462     return HlfirAllLowering{builder, loc}.lower(loweredActuals, argLowering,
463                                                 stmtResultType);
464   if (name == "matmul")
465     return HlfirMatmulLowering{builder, loc}.lower(loweredActuals, argLowering,
466                                                    stmtResultType);
467   if (name == "dot_product")
468     return HlfirDotProductLowering{builder, loc}.lower(
469         loweredActuals, argLowering, stmtResultType);
470   // FIXME: the result may need finalization.
471   if (name == "transpose")
472     return HlfirTransposeLowering{builder, loc}.lower(
473         loweredActuals, argLowering, stmtResultType);
474   if (name == "count")
475     return HlfirCountLowering{builder, loc}.lower(loweredActuals, argLowering,
476                                                   stmtResultType);
477   if (name == "maxval")
478     return HlfirMaxvalLowering{builder, loc}.lower(loweredActuals, argLowering,
479                                                    stmtResultType);
480   if (name == "minval")
481     return HlfirMinvalLowering{builder, loc}.lower(loweredActuals, argLowering,
482                                                    stmtResultType);
483   if (name == "minloc")
484     return HlfirMinlocLowering{builder, loc}.lower(loweredActuals, argLowering,
485                                                    stmtResultType);
486   if (name == "maxloc")
487     return HlfirMaxlocLowering{builder, loc}.lower(loweredActuals, argLowering,
488                                                    stmtResultType);
489   if (name == "cshift")
490     return HlfirCShiftLowering{builder, loc}.lower(loweredActuals, argLowering,
491                                                    stmtResultType);
492   if (name == "reshape")
493     return HlfirReshapeLowering{builder, loc}.lower(loweredActuals, argLowering,
494                                                     stmtResultType);
495   if (mlir::isa<fir::CharacterType>(stmtResultType)) {
496     if (name == "min")
497       return HlfirCharExtremumLowering{builder, loc,
498                                        hlfir::CharExtremumPredicate::min}
499           .lower(loweredActuals, argLowering, stmtResultType);
500     if (name == "max")
501       return HlfirCharExtremumLowering{builder, loc,
502                                        hlfir::CharExtremumPredicate::max}
503           .lower(loweredActuals, argLowering, stmtResultType);
504   }
505   return std::nullopt;
506 }
507