1 //===- SimplifyHLFIRIntrinsics.cpp - Simplify HLFIR Intrinsics ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // Normally transformational intrinsics are lowered to calls to runtime
9 // functions. However, some cases of these intrinsics are faster when
10 // inlined into the calling function.
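// The patterns below currently inline TRANSPOSE, SUM (total reductions and
// partial reductions with a constant DIM), CSHIFT, MATMUL, MATMUL(TRANSPOSE)
// and DOT_PRODUCT in the cases they can handle; everything else keeps going
// through the runtime.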
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/Builder/Complex.h"
14 #include "flang/Optimizer/Builder/FIRBuilder.h"
15 #include "flang/Optimizer/Builder/HLFIRTools.h"
16 #include "flang/Optimizer/Builder/IntrinsicCall.h"
17 #include "flang/Optimizer/Dialect/FIRDialect.h"
18 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
19 #include "flang/Optimizer/HLFIR/HLFIROps.h"
20 #include "flang/Optimizer/HLFIR/Passes.h"
21 #include "mlir/Dialect/Arith/IR/Arith.h"
22 #include "mlir/IR/Location.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 
26 namespace hlfir {
27 #define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS
28 #include "flang/Optimizer/HLFIR/Passes.h.inc"
29 } // namespace hlfir
30 
31 #define DEBUG_TYPE "simplify-hlfir-intrinsics"
32 
33 static llvm::cl::opt<bool> forceMatmulAsElemental(
34     "flang-inline-matmul-as-elemental",
35     llvm::cl::desc("Expand hlfir.matmul as elemental operation"),
36     llvm::cl::init(false));
37 
38 namespace {
39 
40 // Helper class to generate operations related to computing
41 // product of values.
42 class ProductFactory {
43 public:
44   ProductFactory(mlir::Location loc, fir::FirOpBuilder &builder)
45       : loc(loc), builder(builder) {}
46 
47   // Generate an update of the inner product value:
48   //   acc += v1 * v2, OR
49   //   acc += CONJ(v1) * v2, OR
50   //   acc ||= v1 && v2
51   //
52   // The CONJ parameter specifies whether the first complex product argument
53   // needs to be conjugated.
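  // For example, DOT_PRODUCT of complex vectors uses CONJ=true, i.e.
  //   acc += CONJ(v1) * v2,
  // matching Fortran's SUM(CONJG(VECTOR_A) * VECTOR_B) semantics, while
  // MATMUL uses the default CONJ=false form acc += v1 * v2.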
54   template <bool CONJ = false>
55   mlir::Value genAccumulateProduct(mlir::Value acc, mlir::Value v1,
56                                    mlir::Value v2) {
57     mlir::Type resultType = acc.getType();
58     acc = castToProductType(acc, resultType);
59     v1 = castToProductType(v1, resultType);
60     v2 = castToProductType(v2, resultType);
61     mlir::Value result;
62     if (mlir::isa<mlir::FloatType>(resultType)) {
63       result = builder.create<mlir::arith::AddFOp>(
64           loc, acc, builder.create<mlir::arith::MulFOp>(loc, v1, v2));
65     } else if (mlir::isa<mlir::ComplexType>(resultType)) {
66       if constexpr (CONJ)
67         result = fir::IntrinsicLibrary{builder, loc}.genConjg(resultType, v1);
68       else
69         result = v1;
70 
71       result = builder.create<fir::AddcOp>(
72           loc, acc, builder.create<fir::MulcOp>(loc, result, v2));
73     } else if (mlir::isa<mlir::IntegerType>(resultType)) {
74       result = builder.create<mlir::arith::AddIOp>(
75           loc, acc, builder.create<mlir::arith::MulIOp>(loc, v1, v2));
76     } else if (mlir::isa<fir::LogicalType>(resultType)) {
77       result = builder.create<mlir::arith::OrIOp>(
78           loc, acc, builder.create<mlir::arith::AndIOp>(loc, v1, v2));
79     } else {
80       llvm_unreachable("unsupported type");
81     }
82 
83     return builder.createConvert(loc, resultType, result);
84   }
85 
86 private:
87   mlir::Location loc;
88   fir::FirOpBuilder &builder;
89 
90   mlir::Value castToProductType(mlir::Value value, mlir::Type type) {
91     if (mlir::isa<fir::LogicalType>(type))
92       return builder.createConvert(loc, builder.getIntegerType(1), value);
93 
94     // TODO: the multiplications/additions by/of zero resulting from
95     // complex * real are optimized by LLVM under -fno-signed-zeros
96     // -fno-honor-nans.
97     // We can make them disappear by default if we:
98     //   * either expand the complex multiplication into real
99     //     operations, OR
100     //   * set nnan nsz fast-math flags to the complex operations.
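    // For example, when a REAL matrix is multiplied by a COMPLEX one, the
    // REAL element x is widened below into (x, 0.0), and the complex
    // multiplication then contains multiplications by that zero imaginary
    // part, which LLVM may only fold away under nnan/nsz.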
101     if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
102       mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
103       fir::factory::Complex helper(builder, loc);
104       mlir::Type partType = helper.getComplexPartType(type);
105       return helper.insertComplexPart(zeroCmplx,
106                                       castToProductType(value, partType),
107                                       /*isImagPart=*/false);
108     }
109     return builder.createConvert(loc, type, value);
110   }
111 };
112 
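// Inline hlfir.transpose as an hlfir.elemental whose (i, j) element simply
// loads element (j, i) of the argument array.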
113 class TransposeAsElementalConversion
114     : public mlir::OpRewritePattern<hlfir::TransposeOp> {
115 public:
116   using mlir::OpRewritePattern<hlfir::TransposeOp>::OpRewritePattern;
117 
118   llvm::LogicalResult
119   matchAndRewrite(hlfir::TransposeOp transpose,
120                   mlir::PatternRewriter &rewriter) const override {
121     hlfir::ExprType expr = transpose.getType();
122     // TODO: hlfir.elemental supports polymorphic data types now,
123     // so this can be supported.
124     if (expr.isPolymorphic())
125       return rewriter.notifyMatchFailure(transpose,
126                                          "TRANSPOSE of polymorphic type");
127 
128     mlir::Location loc = transpose.getLoc();
129     fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
130     mlir::Type elementType = expr.getElementType();
131     hlfir::Entity array = hlfir::Entity{transpose.getArray()};
132     mlir::Value resultShape = genResultShape(loc, builder, array);
133     llvm::SmallVector<mlir::Value, 1> typeParams;
134     hlfir::genLengthParameters(loc, builder, array, typeParams);
135 
136     auto genKernel = [&array](mlir::Location loc, fir::FirOpBuilder &builder,
137                               mlir::ValueRange inputIndices) -> hlfir::Entity {
138       assert(inputIndices.size() == 2 && "checked in TransposeOp::validate");
139       const std::initializer_list<mlir::Value> initList = {inputIndices[1],
140                                                            inputIndices[0]};
141       mlir::ValueRange transposedIndices(initList);
142       hlfir::Entity element =
143           hlfir::getElementAt(loc, builder, array, transposedIndices);
144       hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, element);
145       return val;
146     };
147     hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
148         loc, builder, elementType, resultShape, typeParams, genKernel,
149         /*isUnordered=*/true, /*polymorphicMold=*/nullptr,
150         transpose.getResult().getType());
151 
152     // It wouldn't be safe to replace block arguments with a different
153     // hlfir.expr type. Types can differ due to differing amounts of shape
154     // information.
155     assert(elementalOp.getResult().getType() ==
156            transpose.getResult().getType());
157 
158     rewriter.replaceOp(transpose, elementalOp);
159     return mlir::success();
160   }
161 
162 private:
163   static mlir::Value genResultShape(mlir::Location loc,
164                                     fir::FirOpBuilder &builder,
165                                     hlfir::Entity array) {
166     llvm::SmallVector<mlir::Value, 2> inExtents =
167         hlfir::genExtentsVector(loc, builder, array);
168 
169     // transpose indices
170     assert(inExtents.size() == 2 && "checked in TransposeOp::validate");
171     return builder.create<fir::ShapeOp>(
172         loc, mlir::ValueRange{inExtents[1], inExtents[0]});
173   }
174 };
175 
176 // Expand the SUM operation into an hlfir.elemental or a reduction loop nest.
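// For example, for A(N,M), SUM(A, DIM=1) becomes an hlfir.elemental over the
// column index whose body is a reduction loop over the row index, while a
// total reduction such as SUM(A) is expanded directly into a scalar
// reduction loop nest.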
177 class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
178 public:
179   using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;
180 
181   llvm::LogicalResult
182   matchAndRewrite(hlfir::SumOp sum,
183                   mlir::PatternRewriter &rewriter) const override {
184     hlfir::Entity array = hlfir::Entity{sum.getArray()};
185     bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
186     mlir::Value dim = sum.getDim();
187     int64_t dimVal = 0;
188     if (!isTotalReduction) {
189       // In the case of a partial reduction, we should ignore the operations
190       // with invalid DIM values. They may appear in dead code
191       // after constant propagation.
192       auto constDim = fir::getIntIfConstant(dim);
193       if (!constDim)
194         return rewriter.notifyMatchFailure(sum, "Nonconstant DIM for SUM");
195       dimVal = *constDim;
196 
197       if ((dimVal <= 0 || dimVal > array.getRank()))
198         return rewriter.notifyMatchFailure(
199             sum, "Invalid DIM for partial SUM reduction");
200     }
201 
202     mlir::Location loc = sum.getLoc();
203     fir::FirOpBuilder builder{rewriter, sum.getOperation()};
204     mlir::Type elementType = hlfir::getFortranElementType(sum.getType());
205     mlir::Value mask = sum.getMask();
206 
207     mlir::Value resultShape, dimExtent;
208     llvm::SmallVector<mlir::Value> arrayExtents;
209     if (isTotalReduction)
210       arrayExtents = hlfir::genExtentsVector(loc, builder, array);
211     else
212       std::tie(resultShape, dimExtent) =
213           genResultShapeForPartialReduction(loc, builder, array, dimVal);
214 
215     // If the mask is present and is a scalar, it is better to load its value
216     // outside of the reduction loop, to make loop unswitching easier.
217     mlir::Value isPresentPred, maskValue;
218     if (mask) {
219       if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
220         // MASK represented by a box might be dynamically optional,
221         // so we have to check for its presence before accessing it.
222         isPresentPred =
223             builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
224       }
225 
226       if (hlfir::Entity{mask}.isScalar())
227         maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
228     }
229 
230     auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
231                          mlir::ValueRange inputIndices) -> hlfir::Entity {
232       // Loop over all indices in the DIM dimension, and reduce all values.
233       // If DIM is not present, do total reduction.
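      // For example, for A(N,M) and DIM=2, inputIndices provides (i) and the
      // loop below reduces A(i, j) over j = 1..M; for a total reduction,
      // inputIndices is empty and the loop nest covers all of A.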
234 
235       // Initial value for the reduction.
236       mlir::Value reductionInitValue =
237           fir::factory::createZeroValue(builder, loc, elementType);
238 
239       // The reduction loop may be unordered if FastMathFlags::reassoc
240       // transformations are allowed. The integer reduction is always
241       // unordered.
242       bool isUnordered = mlir::isa<mlir::IntegerType>(elementType) ||
243                          static_cast<bool>(sum.getFastmath() &
244                                            mlir::arith::FastMathFlags::reassoc);
245 
246       llvm::SmallVector<mlir::Value> extents;
247       if (isTotalReduction)
248         extents = arrayExtents;
249       else
250         extents.push_back(
251             builder.createConvert(loc, builder.getIndexType(), dimExtent));
252 
253       auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
254                          mlir::ValueRange oneBasedIndices,
255                          mlir::ValueRange reductionArgs)
256           -> llvm::SmallVector<mlir::Value, 1> {
257         // Generate the reduction loop-nest body.
258         // The initial reduction value in the innermost loop
259         // is passed via reductionArgs[0].
260         llvm::SmallVector<mlir::Value> indices;
261         if (isTotalReduction) {
262           indices = oneBasedIndices;
263         } else {
264           indices = inputIndices;
265           indices.insert(indices.begin() + dimVal - 1, oneBasedIndices[0]);
266         }
267 
268         mlir::Value reductionValue = reductionArgs[0];
269         fir::IfOp ifOp;
270         if (mask) {
271           // Make the reduction value update conditional on the value
272           // of the mask.
273           if (!maskValue) {
274             // If the mask is an array, use the elemental and the loop indices
275             // to address the proper mask element.
276             maskValue =
277                 genMaskValue(loc, builder, mask, isPresentPred, indices);
278           }
279           mlir::Value isUnmasked = builder.create<fir::ConvertOp>(
280               loc, builder.getI1Type(), maskValue);
281           ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
282                                            /*withElseRegion=*/true);
283           // In the 'else' block return the current reduction value.
284           builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
285           builder.create<fir::ResultOp>(loc, reductionValue);
286 
287           // In the 'then' block do the actual addition.
288           builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
289         }
290 
291         hlfir::Entity element =
292             hlfir::getElementAt(loc, builder, array, indices);
293         hlfir::Entity elementValue =
294             hlfir::loadTrivialScalar(loc, builder, element);
295         // NOTE: we could use Kahan summation the same way as the runtime
296         // does (e.g. when fast-math is not allowed), but let's start with
297         // the simple version.
298         reductionValue =
299             genScalarAdd(loc, builder, reductionValue, elementValue);
300 
301         if (ifOp) {
302           builder.create<fir::ResultOp>(loc, reductionValue);
303           builder.setInsertionPointAfter(ifOp);
304           reductionValue = ifOp.getResult(0);
305         }
306 
307         return {reductionValue};
308       };
309 
310       llvm::SmallVector<mlir::Value, 1> reductionFinalValues =
311           hlfir::genLoopNestWithReductions(loc, builder, extents,
312                                            {reductionInitValue}, genBody,
313                                            isUnordered);
314       return hlfir::Entity{reductionFinalValues[0]};
315     };
316 
317     if (isTotalReduction) {
318       hlfir::Entity result = genKernel(loc, builder, mlir::ValueRange{});
319       rewriter.replaceOp(sum, result);
320       return mlir::success();
321     }
322 
323     hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
324         loc, builder, elementType, resultShape, {}, genKernel,
325         /*isUnordered=*/true, /*polymorphicMold=*/nullptr,
326         sum.getResult().getType());
327 
328     // It wouldn't be safe to replace block arguments with a different
329     // hlfir.expr type. Types can differ due to differing amounts of shape
330     // information.
331     assert(elementalOp.getResult().getType() == sum.getResult().getType());
332 
333     rewriter.replaceOp(sum, elementalOp);
334     return mlir::success();
335   }
336 
337 private:
338   // Return fir.shape specifying the shape of the result
339   // of a SUM reduction with DIM=dimVal. The second return value
340   // is the extent of the DIM dimension.
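  // For example, for A(N1,N2,N3) and DIM=2, the returned shape is (N1,N3)
  // and the returned extent is N2.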
341   static std::tuple<mlir::Value, mlir::Value>
342   genResultShapeForPartialReduction(mlir::Location loc,
343                                     fir::FirOpBuilder &builder,
344                                     hlfir::Entity array, int64_t dimVal) {
345     llvm::SmallVector<mlir::Value> inExtents =
346         hlfir::genExtentsVector(loc, builder, array);
347     assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
348            "DIM must be present and a positive constant not exceeding "
349            "the array's rank");
350 
351     mlir::Value dimExtent = inExtents[dimVal - 1];
352     inExtents.erase(inExtents.begin() + dimVal - 1);
353     return {builder.create<fir::ShapeOp>(loc, inExtents), dimExtent};
354   }
355 
356   // Generate scalar addition of the two values (of the same data type).
357   static mlir::Value genScalarAdd(mlir::Location loc,
358                                   fir::FirOpBuilder &builder,
359                                   mlir::Value value1, mlir::Value value2) {
360     mlir::Type ty = value1.getType();
361     assert(ty == value2.getType() && "reduction values' types do not match");
362     if (mlir::isa<mlir::FloatType>(ty))
363       return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
364     else if (mlir::isa<mlir::ComplexType>(ty))
365       return builder.create<fir::AddcOp>(loc, value1, value2);
366     else if (mlir::isa<mlir::IntegerType>(ty))
367       return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
368 
369     llvm_unreachable("unsupported SUM reduction type");
370   }
371 
372   static mlir::Value genMaskValue(mlir::Location loc,
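  // Compute the mask value to be used for one reduction iteration.
  // An absent (dynamically optional) MASK behaves as if it were .TRUE.,
  // a scalar MASK is loaded directly, and an array MASK is addressed
  // with the given indices.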
373                                   fir::FirOpBuilder &builder, mlir::Value mask,
374                                   mlir::Value isPresentPred,
375                                   mlir::ValueRange indices) {
376     mlir::OpBuilder::InsertionGuard guard(builder);
377     fir::IfOp ifOp;
378     mlir::Type maskType =
379         hlfir::getFortranElementType(fir::unwrapPassByRefType(mask.getType()));
380     if (isPresentPred) {
381       ifOp = builder.create<fir::IfOp>(loc, maskType, isPresentPred,
382                                        /*withElseRegion=*/true);
383 
384       // Use 'true', if the mask is not present.
385       builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
386       mlir::Value trueValue = builder.createBool(loc, true);
387       trueValue = builder.createConvert(loc, maskType, trueValue);
388       builder.create<fir::ResultOp>(loc, trueValue);
389 
390       // Load the mask value, if the mask is present.
391       builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
392     }
393 
394     hlfir::Entity maskVar{mask};
395     if (maskVar.isScalar()) {
396       if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
397         // MASK may be a boxed scalar.
398         mlir::Value addr = hlfir::genVariableRawAddress(loc, builder, maskVar);
399         mask = builder.create<fir::LoadOp>(loc, hlfir::Entity{addr});
400       } else {
401         mask = hlfir::loadTrivialScalar(loc, builder, maskVar);
402       }
403     } else {
404       // Load from the mask array.
405       assert(!indices.empty() && "no indices for addressing the mask array");
406       maskVar = hlfir::getElementAt(loc, builder, maskVar, indices);
407       mask = hlfir::loadTrivialScalar(loc, builder, maskVar);
408     }
409 
410     if (!isPresentPred)
411       return mask;
412 
413     builder.create<fir::ResultOp>(loc, mask);
414     return ifOp.getResult(0);
415   }
416 };
417 
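// Inline hlfir.cshift as an hlfir.elemental that, for every result element,
// loads the circularly shifted element of the argument array along the
// shift dimension.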
418 class CShiftAsElementalConversion
419     : public mlir::OpRewritePattern<hlfir::CShiftOp> {
420 public:
421   using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern;
422 
423   llvm::LogicalResult
424   matchAndRewrite(hlfir::CShiftOp cshift,
425                   mlir::PatternRewriter &rewriter) const override {
426     using Fortran::common::maxRank;
427 
428     hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType());
429     assert(expr &&
430            "expected an expression type for the result of hlfir.cshift");
431     unsigned arrayRank = expr.getRank();
432     // For a 1D CSHIFT, we may assume that the DIM argument
433     // (whether it is present or absent) is equal to 1; otherwise
434     // the program is illegal.
435     int64_t dimVal = 1;
436     if (arrayRank != 1)
437       if (mlir::Value dim = cshift.getDim()) {
438         auto constDim = fir::getIntIfConstant(dim);
439         if (!constDim)
440           return rewriter.notifyMatchFailure(cshift,
441                                              "Nonconstant DIM for CSHIFT");
442         dimVal = *constDim;
443       }
444 
445     if (dimVal <= 0 || dimVal > arrayRank)
446       return rewriter.notifyMatchFailure(cshift, "Invalid DIM for CSHIFT");
447 
448     mlir::Location loc = cshift.getLoc();
449     fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
450     mlir::Type elementType = expr.getElementType();
451     hlfir::Entity array = hlfir::Entity{cshift.getArray()};
452     mlir::Value arrayShape = hlfir::genShape(loc, builder, array);
453     llvm::SmallVector<mlir::Value> arrayExtents =
454         hlfir::getExplicitExtentsFromShape(arrayShape, builder);
455     llvm::SmallVector<mlir::Value, 1> typeParams;
456     hlfir::genLengthParameters(loc, builder, array, typeParams);
457     hlfir::Entity shift = hlfir::Entity{cshift.getShift()};
458     // The new index computation involves MODULO, which is not implemented
459     // for IndexType, so use I64 instead.
460     mlir::Type calcType = builder.getI64Type();
461 
462     mlir::Value one = builder.createIntegerConstant(loc, calcType, 1);
463     mlir::Value shiftVal;
464     if (shift.isScalar()) {
465       shiftVal = hlfir::loadTrivialScalar(loc, builder, shift);
466       shiftVal = builder.createConvert(loc, calcType, shiftVal);
467     }
468 
469     auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
470                          mlir::ValueRange inputIndices) -> hlfir::Entity {
471       llvm::SmallVector<mlir::Value, maxRank> indices{inputIndices};
472       if (!shift.isScalar()) {
473         // When the array is not a vector, the section
474         // (s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n))
475         // of the result has a value equal to:
476         // CSHIFT(ARRAY(s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n)),
477         //        SH, 1),
478         // where SH is either SHIFT (if scalar) or
479         // SHIFT(s(1), s(2), ..., s(dim-1), s(dim+1), ..., s(n)).
480         llvm::SmallVector<mlir::Value, maxRank> shiftIndices{indices};
481         shiftIndices.erase(shiftIndices.begin() + dimVal - 1);
482         hlfir::Entity shiftElement =
483             hlfir::getElementAt(loc, builder, shift, shiftIndices);
484         shiftVal = hlfir::loadTrivialScalar(loc, builder, shiftElement);
485         shiftVal = builder.createConvert(loc, calcType, shiftVal);
486       }
487 
488       // Element i of the result (1-based) is element
489       // 'MODULO(i + SH - 1, SIZE(ARRAY)) + 1' (1-based) of the original
490       // ARRAY (or its section, when ARRAY is not a vector).
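      // For example, with SIZE(ARRAY) = 5 and SH = 2, result element i = 4
      // reads ARRAY(MODULO(4 + 2 - 1, 5) + 1) = ARRAY(1), and with SH = -1,
      // result element i = 1 reads ARRAY(MODULO(-1, 5) + 1) = ARRAY(5),
      // because Fortran MODULO takes the sign of the divisor.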
491       mlir::Value index =
492           builder.createConvert(loc, calcType, inputIndices[dimVal - 1]);
493       mlir::Value extent = arrayExtents[dimVal - 1];
494       mlir::Value newIndex =
495           builder.create<mlir::arith::AddIOp>(loc, index, shiftVal);
496       newIndex = builder.create<mlir::arith::SubIOp>(loc, newIndex, one);
497       newIndex = fir::IntrinsicLibrary{builder, loc}.genModulo(
498           calcType, {newIndex, builder.createConvert(loc, calcType, extent)});
499       newIndex = builder.create<mlir::arith::AddIOp>(loc, newIndex, one);
500       newIndex = builder.createConvert(loc, builder.getIndexType(), newIndex);
501 
502       indices[dimVal - 1] = newIndex;
503       hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
504       return hlfir::loadTrivialScalar(loc, builder, element);
505     };
506 
507     hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
508         loc, builder, elementType, arrayShape, typeParams, genKernel,
509         /*isUnordered=*/true,
510         array.isPolymorphic() ? static_cast<mlir::Value>(array) : nullptr,
511         cshift.getResult().getType());
512     rewriter.replaceOp(cshift, elementalOp);
513     return mlir::success();
514   }
515 };
516 
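// Inline hlfir.matmul and hlfir.matmul_transpose. MATMUL(TRANSPOSE) is always
// expanded as an hlfir.elemental (as is MATMUL under
// -flang-inline-matmul-as-elemental), computing each result element as an
// inner product; otherwise MATMUL is expanded as an hlfir.eval_in_mem
// containing explicit loop nests that resemble the runtime implementation.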
517 template <typename Op>
518 class MatmulConversion : public mlir::OpRewritePattern<Op> {
519 public:
520   using mlir::OpRewritePattern<Op>::OpRewritePattern;
521 
522   llvm::LogicalResult
523   matchAndRewrite(Op matmul, mlir::PatternRewriter &rewriter) const override {
524     mlir::Location loc = matmul.getLoc();
525     fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
526     hlfir::Entity lhs = hlfir::Entity{matmul.getLhs()};
527     hlfir::Entity rhs = hlfir::Entity{matmul.getRhs()};
528     mlir::Value resultShape, innerProductExtent;
529     std::tie(resultShape, innerProductExtent) =
530         genResultShape(loc, builder, lhs, rhs);
531 
532     if (forceMatmulAsElemental || isMatmulTranspose) {
533       // Generate hlfir.elemental that produces the result of
534       // MATMUL/MATMUL(TRANSPOSE).
535       // Note that this implementation is very suboptimal for MATMUL,
536       // but is quite good for MATMUL(TRANSPOSE), e.g.:
537       //   R(1:N) = R(1:N) + MATMUL(TRANSPOSE(X(1:N,1:N)), Y(1:N))
538       // Inlining MATMUL(TRANSPOSE) as hlfir.elemental may result
539       // in merging the inner product computation with the elemental
540       // addition. Note that the inner product computation will
541       // benefit from processing the lowermost dimensions of X and Y,
542       // which may be the best when they are contiguous.
543       //
544       // This is why we always inline MATMUL(TRANSPOSE) as an elemental.
545       // MATMUL is inlined below by default unless forceMatmulAsElemental.
546       hlfir::ExprType resultType =
547           mlir::cast<hlfir::ExprType>(matmul.getType());
548       hlfir::ElementalOp newOp = genElementalMatmul(
549           loc, builder, resultType, resultShape, lhs, rhs, innerProductExtent);
550       rewriter.replaceOp(matmul, newOp);
551       return mlir::success();
552     }
553 
554     // Generate hlfir.eval_in_mem to mimic the MATMUL implementation
555     // from Fortran runtime. The implementation needs to operate
556     // with the result array as an in-memory object.
557     hlfir::EvaluateInMemoryOp evalOp =
558         builder.create<hlfir::EvaluateInMemoryOp>(
559             loc, mlir::cast<hlfir::ExprType>(matmul.getType()), resultShape);
560     builder.setInsertionPointToStart(&evalOp.getBody().front());
561 
562     // Embox the raw array pointer to simplify designating it.
563     // TODO: this currently results in redundant lower bounds
564     // addition for the designator, but this should be fixed in
565     // hlfir::Entity::mayHaveNonDefaultLowerBounds().
566     mlir::Value resultArray = evalOp.getMemory();
567     mlir::Type arrayType = fir::dyn_cast_ptrEleTy(resultArray.getType());
568     resultArray = builder.createBox(loc, fir::BoxType::get(arrayType),
569                                     resultArray, resultShape, /*slice=*/nullptr,
570                                     /*lengths=*/{}, /*tdesc=*/nullptr);
571 
572     // The contiguous MATMUL version is best for the cases
573     // where the input arrays and (maybe) the result are contiguous
574     // in their lowermost dimensions.
575     // This is especially true when LLVM can recognize the contiguity
576     // and vectorize the loops properly.
577     // Note that the contiguous MATMUL inlining is correct
578     // even when the input arrays are not contiguous.
579     // TODO: we can try to recognize the cases when the contiguity
580     // is not statically obvious and try to generate an explicitly
581     // contiguous version under a dynamic check. This should allow
582     // LLVM to vectorize the loops better. Note that this can
583     // also be postponed up to the LoopVersioning pass.
584     // The fallback implementation may use genElementalMatmul() with
585     // an hlfir.assign into the result of eval_in_mem.
586     mlir::LogicalResult rewriteResult =
587         genContiguousMatmul(loc, builder, hlfir::Entity{resultArray},
588                             resultShape, lhs, rhs, innerProductExtent);
589 
590     if (mlir::failed(rewriteResult)) {
591       // Erase the unclaimed eval_in_mem op.
592       rewriter.eraseOp(evalOp);
593       return rewriter.notifyMatchFailure(matmul,
594                                          "genContiguousMatmul() failed");
595     }
596 
597     rewriter.replaceOp(matmul, evalOp);
598     return mlir::success();
599   }
600 
601 private:
602   static constexpr bool isMatmulTranspose =
603       std::is_same_v<Op, hlfir::MatmulTransposeOp>;
604 
605   // Return a tuple of:
606   //   * A fir.shape operation representing the shape of the result
607   //     of a MATMUL/MATMUL(TRANSPOSE).
608   //   * An extent of the dimensions of the input array
609   //     that are processed during the inner product computation.
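  // For example, for MATMUL with LHS(NROWS,N) and RHS(N,NCOLS) the result
  // shape is (NROWS,NCOLS) and the inner product extent is N; for
  // MATMUL(TRANSPOSE) with LHS(N,NROWS) the result shape is the same and
  // the inner product extent is still N.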
610   static std::tuple<mlir::Value, mlir::Value>
611   genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
612                  hlfir::Entity input1, hlfir::Entity input2) {
613     llvm::SmallVector<mlir::Value, 2> input1Extents =
614         hlfir::genExtentsVector(loc, builder, input1);
615     llvm::SmallVector<mlir::Value, 2> input2Extents =
616         hlfir::genExtentsVector(loc, builder, input2);
617 
618     llvm::SmallVector<mlir::Value, 2> newExtents;
619     mlir::Value innerProduct1Extent, innerProduct2Extent;
620     if (input1Extents.size() == 1) {
621       assert(!isMatmulTranspose &&
622              "hlfir.matmul_transpose's first operand must be rank-2 array");
623       assert(input2Extents.size() == 2 &&
624              "hlfir.matmul second argument must be rank-2 array");
625       newExtents.push_back(input2Extents[1]);
626       innerProduct1Extent = input1Extents[0];
627       innerProduct2Extent = input2Extents[0];
628     } else {
629       if (input2Extents.size() == 1) {
630         assert(input1Extents.size() == 2 &&
631                "hlfir.matmul first argument must be rank-2 array");
632         if constexpr (isMatmulTranspose)
633           newExtents.push_back(input1Extents[1]);
634         else
635           newExtents.push_back(input1Extents[0]);
636       } else {
637         assert(input1Extents.size() == 2 && input2Extents.size() == 2 &&
638                "hlfir.matmul arguments must be rank-2 arrays");
639         if constexpr (isMatmulTranspose)
640           newExtents.push_back(input1Extents[1]);
641         else
642           newExtents.push_back(input1Extents[0]);
643 
644         newExtents.push_back(input2Extents[1]);
645       }
646       if constexpr (isMatmulTranspose)
647         innerProduct1Extent = input1Extents[0];
648       else
649         innerProduct1Extent = input1Extents[1];
650 
651       innerProduct2Extent = input2Extents[0];
652     }
653     // The inner product dimensions of the input arrays
654     // must match. Pick the best (e.g. constant) out of them
655     // so that the inner product loop bound can be used in
656     // optimizations.
657     llvm::SmallVector<mlir::Value> innerProductExtent =
658         fir::factory::deduceOptimalExtents({innerProduct1Extent},
659                                            {innerProduct2Extent});
660     return {builder.create<fir::ShapeOp>(loc, newExtents),
661             innerProductExtent[0]};
662   }
663 
664   static mlir::LogicalResult
665   genContiguousMatmul(mlir::Location loc, fir::FirOpBuilder &builder,
666                       hlfir::Entity result, mlir::Value resultShape,
667                       hlfir::Entity lhs, hlfir::Entity rhs,
668                       mlir::Value innerProductExtent) {
669     // This code does not support MATMUL(TRANSPOSE); that case is supposed
670     // to be inlined as hlfir.elemental.
671     if constexpr (isMatmulTranspose)
672       return mlir::failure();
673 
674     mlir::OpBuilder::InsertionGuard guard(builder);
675     mlir::Type resultElementType = result.getFortranElementType();
676     llvm::SmallVector<mlir::Value, 2> resultExtents =
677         mlir::cast<fir::ShapeOp>(resultShape.getDefiningOp()).getExtents();
678 
679     // The inner product loop may be unordered if FastMathFlags::reassoc
680     // transformations are allowed. The integer/logical inner product is
681     // always unordered.
682     // Note that isUnordered is currently applied to all loops
683     // in the loop nests generated below, while it has to be applied
684     // only to one.
685     bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
686                        mlir::isa<fir::LogicalType>(resultElementType) ||
687                        static_cast<bool>(builder.getFastMathFlags() &
688                                          mlir::arith::FastMathFlags::reassoc);
689 
690     // Insert the initialization loop nest that fills the whole result with
691     // zeroes.
692     mlir::Value initValue =
693         fir::factory::createZeroValue(builder, loc, resultElementType);
694     auto genInitBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
695                            mlir::ValueRange oneBasedIndices,
696                            mlir::ValueRange reductionArgs)
697         -> llvm::SmallVector<mlir::Value, 0> {
698       hlfir::Entity resultElement =
699           hlfir::getElementAt(loc, builder, result, oneBasedIndices);
700       builder.create<hlfir::AssignOp>(loc, initValue, resultElement);
701       return {};
702     };
703 
704     hlfir::genLoopNestWithReductions(loc, builder, resultExtents,
705                                      /*reductionInits=*/{}, genInitBody,
706                                      /*isUnordered=*/true);
707 
708     if (lhs.getRank() == 2 && rhs.getRank() == 2) {
709       //   LHS(NROWS,N) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS)
710       //
711       // Insert the computation loop nest:
712       //   DO 2 K = 1, N
713       //    DO 2 J = 1, NCOLS
714       //     DO 2 I = 1, NROWS
715       //   2  RESULT(I,J) = RESULT(I,J) + LHS(I,K)*RHS(K,J)
716       auto genMatrixMatrix = [&](mlir::Location loc, fir::FirOpBuilder &builder,
717                                  mlir::ValueRange oneBasedIndices,
718                                  mlir::ValueRange reductionArgs)
719           -> llvm::SmallVector<mlir::Value, 0> {
720         mlir::Value I = oneBasedIndices[0];
721         mlir::Value J = oneBasedIndices[1];
722         mlir::Value K = oneBasedIndices[2];
723         hlfir::Entity resultElement =
724             hlfir::getElementAt(loc, builder, result, {I, J});
725         hlfir::Entity resultElementValue =
726             hlfir::loadTrivialScalar(loc, builder, resultElement);
727         hlfir::Entity lhsElementValue =
728             hlfir::loadElementAt(loc, builder, lhs, {I, K});
729         hlfir::Entity rhsElementValue =
730             hlfir::loadElementAt(loc, builder, rhs, {K, J});
731         mlir::Value productValue =
732             ProductFactory{loc, builder}.genAccumulateProduct(
733                 resultElementValue, lhsElementValue, rhsElementValue);
734         builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
735         return {};
736       };
737 
738       // Note that the loops are inserted in reverse order,
739       // so innerProductExtent should be passed as the last extent.
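      // With extents {NROWS, NCOLS, N}, the generated nest is K outermost,
      // then J, then I innermost, matching the DO loops above, while
      // oneBasedIndices keep the {I, J, K} order of the extents.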
740       hlfir::genLoopNestWithReductions(
741           loc, builder,
742           {resultExtents[0], resultExtents[1], innerProductExtent},
743           /*reductionInits=*/{}, genMatrixMatrix, isUnordered);
744       return mlir::success();
745     }
746 
747     if (lhs.getRank() == 2 && rhs.getRank() == 1) {
748       //   LHS(NROWS,N) * RHS(N) -> RESULT(NROWS)
749       //
750       // Insert the computation loop nest:
751       //   DO 2 K = 1, N
752       //    DO 2 J = 1, NROWS
753       //   2 RES(J) = RES(J) + LHS(J,K)*RHS(K)
754       auto genMatrixVector = [&](mlir::Location loc, fir::FirOpBuilder &builder,
755                                  mlir::ValueRange oneBasedIndices,
756                                  mlir::ValueRange reductionArgs)
757           -> llvm::SmallVector<mlir::Value, 0> {
758         mlir::Value J = oneBasedIndices[0];
759         mlir::Value K = oneBasedIndices[1];
760         hlfir::Entity resultElement =
761             hlfir::getElementAt(loc, builder, result, {J});
762         hlfir::Entity resultElementValue =
763             hlfir::loadTrivialScalar(loc, builder, resultElement);
764         hlfir::Entity lhsElementValue =
765             hlfir::loadElementAt(loc, builder, lhs, {J, K});
766         hlfir::Entity rhsElementValue =
767             hlfir::loadElementAt(loc, builder, rhs, {K});
768         mlir::Value productValue =
769             ProductFactory{loc, builder}.genAccumulateProduct(
770                 resultElementValue, lhsElementValue, rhsElementValue);
771         builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
772         return {};
773       };
774       hlfir::genLoopNestWithReductions(
775           loc, builder, {resultExtents[0], innerProductExtent},
776           /*reductionInits=*/{}, genMatrixVector, isUnordered);
777       return mlir::success();
778     }
779     if (lhs.getRank() == 1 && rhs.getRank() == 2) {
780       //   LHS(N) * RHS(N,NCOLS) -> RESULT(NCOLS)
781       //
782       // Insert the computation loop nest:
783       //   DO 2 K = 1, N
784       //    DO 2 J = 1, NCOLS
785       //   2 RES(J) = RES(J) + LHS(K)*RHS(K,J)
786       auto genVectorMatrix = [&](mlir::Location loc, fir::FirOpBuilder &builder,
787                                  mlir::ValueRange oneBasedIndices,
788                                  mlir::ValueRange reductionArgs)
789           -> llvm::SmallVector<mlir::Value, 0> {
790         mlir::Value J = oneBasedIndices[0];
791         mlir::Value K = oneBasedIndices[1];
792         hlfir::Entity resultElement =
793             hlfir::getElementAt(loc, builder, result, {J});
794         hlfir::Entity resultElementValue =
795             hlfir::loadTrivialScalar(loc, builder, resultElement);
796         hlfir::Entity lhsElementValue =
797             hlfir::loadElementAt(loc, builder, lhs, {K});
798         hlfir::Entity rhsElementValue =
799             hlfir::loadElementAt(loc, builder, rhs, {K, J});
800         mlir::Value productValue =
801             ProductFactory{loc, builder}.genAccumulateProduct(
802                 resultElementValue, lhsElementValue, rhsElementValue);
803         builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
804         return {};
805       };
806       hlfir::genLoopNestWithReductions(
807           loc, builder, {resultExtents[0], innerProductExtent},
808           /*reductionInits=*/{}, genVectorMatrix, isUnordered);
809       return mlir::success();
810     }
811 
812     llvm_unreachable("unsupported MATMUL arguments' ranks");
813   }
814 
815   static hlfir::ElementalOp
816   genElementalMatmul(mlir::Location loc, fir::FirOpBuilder &builder,
817                      hlfir::ExprType resultType, mlir::Value resultShape,
818                      hlfir::Entity lhs, hlfir::Entity rhs,
819                      mlir::Value innerProductExtent) {
820     mlir::OpBuilder::InsertionGuard guard(builder);
821     mlir::Type resultElementType = resultType.getElementType();
822     auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
823                          mlir::ValueRange resultIndices) -> hlfir::Entity {
824       mlir::Value initValue =
825           fir::factory::createZeroValue(builder, loc, resultElementType);
826       // The inner product loop may be unordered if FastMathFlags::reassoc
827       // transformations are allowed. The integer/logical inner product is
828       // always unordered.
829       bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
830                          mlir::isa<fir::LogicalType>(resultElementType) ||
831                          static_cast<bool>(builder.getFastMathFlags() &
832                                            mlir::arith::FastMathFlags::reassoc);
833 
834       auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
835                          mlir::ValueRange oneBasedIndices,
836                          mlir::ValueRange reductionArgs)
837           -> llvm::SmallVector<mlir::Value, 1> {
838         llvm::SmallVector<mlir::Value, 2> lhsIndices;
839         llvm::SmallVector<mlir::Value, 2> rhsIndices;
840         // MATMUL:
841         //   LHS(NROWS,N) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS)
842         //   LHS(NROWS,N) * RHS(N) -> RESULT(NROWS)
843         //   LHS(N) * RHS(N,NCOLS) -> RESULT(NCOLS)
844         //
845         // MATMUL(TRANSPOSE):
846         //   TRANSPOSE(LHS(N,NROWS)) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS)
847         //   TRANSPOSE(LHS(N,NROWS)) * RHS(N) -> RESULT(NROWS)
848         //
849         // The resultIndices iterate over (NROWS[,NCOLS]).
850         // The oneBasedIndices iterate over (N).
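        // For example, for C = MATMUL(A(NROWS,N), B(N,NCOLS)), the element at
        // resultIndices (i, j) is the reduction over k of A(i, k) * B(k, j);
        // for MATMUL(TRANSPOSE(A(N,NROWS)), B) the LHS access becomes A(k, i)
        // via the index swap below.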
851         if (lhs.getRank() > 1)
852           lhsIndices.push_back(resultIndices[0]);
853         lhsIndices.push_back(oneBasedIndices[0]);
854 
855         if constexpr (isMatmulTranspose) {
856           // Swap the LHS indices for TRANSPOSE.
857           std::swap(lhsIndices[0], lhsIndices[1]);
858         }
859 
860         rhsIndices.push_back(oneBasedIndices[0]);
861         if (rhs.getRank() > 1)
862           rhsIndices.push_back(resultIndices.back());
863 
864         hlfir::Entity lhsElementValue =
865             hlfir::loadElementAt(loc, builder, lhs, lhsIndices);
866         hlfir::Entity rhsElementValue =
867             hlfir::loadElementAt(loc, builder, rhs, rhsIndices);
868         mlir::Value productValue =
869             ProductFactory{loc, builder}.genAccumulateProduct(
870                 reductionArgs[0], lhsElementValue, rhsElementValue);
871         return {productValue};
872       };
873       llvm::SmallVector<mlir::Value, 1> innerProductValue =
874           hlfir::genLoopNestWithReductions(loc, builder, {innerProductExtent},
875                                            {initValue}, genBody, isUnordered);
876       return hlfir::Entity{innerProductValue[0]};
877     };
878     hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
879         loc, builder, resultElementType, resultShape, /*typeParams=*/{},
880         genKernel,
881         /*isUnordered=*/true, /*polymorphicMold=*/nullptr, resultType);
882 
883     return elementalOp;
884   }
885 };
886 
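// Inline hlfir.dot_product as a single reduction loop computing
// SUM(CONJG(LHS) * RHS) for complex arguments, SUM(LHS * RHS) for other
// numeric types, and ANY(LHS .AND. RHS) for logical arguments.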
887 class DotProductConversion
888     : public mlir::OpRewritePattern<hlfir::DotProductOp> {
889 public:
890   using mlir::OpRewritePattern<hlfir::DotProductOp>::OpRewritePattern;
891 
892   llvm::LogicalResult
893   matchAndRewrite(hlfir::DotProductOp product,
894                   mlir::PatternRewriter &rewriter) const override {
895     hlfir::Entity op = hlfir::Entity{product};
896     if (!op.isScalar())
897       return rewriter.notifyMatchFailure(product, "produces non-scalar result");
898 
899     mlir::Location loc = product.getLoc();
900     fir::FirOpBuilder builder{rewriter, product.getOperation()};
901     hlfir::Entity lhs = hlfir::Entity{product.getLhs()};
902     hlfir::Entity rhs = hlfir::Entity{product.getRhs()};
903     mlir::Type resultElementType = product.getType();
904     bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
905                        mlir::isa<fir::LogicalType>(resultElementType) ||
906                        static_cast<bool>(builder.getFastMathFlags() &
907                                          mlir::arith::FastMathFlags::reassoc);
908 
909     mlir::Value extent = genProductExtent(loc, builder, lhs, rhs);
910 
911     auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
912                        mlir::ValueRange oneBasedIndices,
913                        mlir::ValueRange reductionArgs)
914         -> llvm::SmallVector<mlir::Value, 1> {
915       hlfir::Entity lhsElementValue =
916           hlfir::loadElementAt(loc, builder, lhs, oneBasedIndices);
917       hlfir::Entity rhsElementValue =
918           hlfir::loadElementAt(loc, builder, rhs, oneBasedIndices);
919       mlir::Value productValue =
920           ProductFactory{loc, builder}.genAccumulateProduct</*CONJ=*/true>(
921               reductionArgs[0], lhsElementValue, rhsElementValue);
922       return {productValue};
923     };
924 
925     mlir::Value initValue =
926         fir::factory::createZeroValue(builder, loc, resultElementType);
927 
928     llvm::SmallVector<mlir::Value, 1> result = hlfir::genLoopNestWithReductions(
929         loc, builder, {extent},
930         /*reductionInits=*/{initValue}, genBody, isUnordered);
931 
932     rewriter.replaceOp(product, result[0]);
933     return mlir::success();
934   }
935 
936 private:
937   static mlir::Value genProductExtent(mlir::Location loc,
938                                       fir::FirOpBuilder &builder,
939                                       hlfir::Entity input1,
940                                       hlfir::Entity input2) {
941     llvm::SmallVector<mlir::Value, 1> input1Extents =
942         hlfir::genExtentsVector(loc, builder, input1);
943     llvm::SmallVector<mlir::Value, 1> input2Extents =
944         hlfir::genExtentsVector(loc, builder, input2);
945 
946     assert(input1Extents.size() == 1 && input2Extents.size() == 1 &&
947            "hlfir.dot_product arguments must be vectors");
948     llvm::SmallVector<mlir::Value, 1> extent =
949         fir::factory::deduceOptimalExtents(input1Extents, input2Extents);
950     return extent[0];
951   }
952 };
953 
954 class SimplifyHLFIRIntrinsics
955     : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
956 public:
957   using SimplifyHLFIRIntrinsicsBase<
958       SimplifyHLFIRIntrinsics>::SimplifyHLFIRIntrinsicsBase;
959 
960   void runOnOperation() override {
961     mlir::MLIRContext *context = &getContext();
962 
963     mlir::GreedyRewriteConfig config;
964     // Prevent the pattern driver from merging blocks
965     config.enableRegionSimplification =
966         mlir::GreedySimplifyRegionLevel::Disabled;
967 
968     mlir::RewritePatternSet patterns(context);
969     patterns.insert<TransposeAsElementalConversion>(context);
970     patterns.insert<SumAsElementalConversion>(context);
971     patterns.insert<CShiftAsElementalConversion>(context);
972     patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context);
973 
974     // If forceMatmulAsElemental is false, then hlfir.matmul inlining
975     // will introduce an hlfir.eval_in_mem operation with new memory side
976     // effects. This conflicts with CSE and optimized bufferization, e.g.:
977     //   A(1:N,1:N) =  A(1:N,1:N) - MATMUL(...)
978     // If we introduce hlfir.eval_in_mem before CSE, then the current
979     // MLIR CSE won't be able to optimize the trivial loads of 'N' value
980     // that happen before and after hlfir.matmul.
981     // If 'N' loads are not optimized, then the optimized bufferization
982     // won't be able to prove that the slices of A are identical
983     // on both sides of the assignment.
984     // This is actually a CSE problem, but we can work around it
985     // for the time being.
986     if (forceMatmulAsElemental || this->allowNewSideEffects)
987       patterns.insert<MatmulConversion<hlfir::MatmulOp>>(context);
988 
989     patterns.insert<DotProductConversion>(context);
990 
991     if (mlir::failed(mlir::applyPatternsGreedily(
992             getOperation(), std::move(patterns), config))) {
993       mlir::emitError(getOperation()->getLoc(),
994                       "failure in HLFIR intrinsic simplification");
995       signalPassFailure();
996     }
997   }
998 };
999 } // namespace
1000