xref: /llvm-project/flang/lib/Lower/VectorSubscripts.cpp (revision 77d8cfb3c50e3341d65af1f9e442004bbd77af9b)
1 //===-- VectorSubscripts.cpp -- Vector subscripts tools -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Lower/VectorSubscripts.h"
14 #include "flang/Lower/AbstractConverter.h"
15 #include "flang/Lower/Support/Utils.h"
16 #include "flang/Optimizer/Builder/Character.h"
17 #include "flang/Optimizer/Builder/Complex.h"
18 #include "flang/Optimizer/Builder/FIRBuilder.h"
19 #include "flang/Optimizer/Builder/Todo.h"
20 #include "flang/Semantics/expression.h"
21 
22 namespace {
23 /// Helper class to lower a designator containing vector subscripts into a
24 /// lowered representation that can be worked with.
25 class VectorSubscriptBoxBuilder {
26 public:
VectorSubscriptBoxBuilder(mlir::Location loc,Fortran::lower::AbstractConverter & converter,Fortran::lower::StatementContext & stmtCtx)27   VectorSubscriptBoxBuilder(mlir::Location loc,
28                             Fortran::lower::AbstractConverter &converter,
29                             Fortran::lower::StatementContext &stmtCtx)
30       : converter{converter}, stmtCtx{stmtCtx}, loc{loc} {}
31 
gen(const Fortran::lower::SomeExpr & expr)32   Fortran::lower::VectorSubscriptBox gen(const Fortran::lower::SomeExpr &expr) {
33     elementType = genDesignator(expr);
34     return Fortran::lower::VectorSubscriptBox(
35         std::move(loweredBase), std::move(loweredSubscripts),
36         std::move(componentPath), substringBounds, elementType);
37   }
38 
39 private:
40   using LoweredVectorSubscript =
41       Fortran::lower::VectorSubscriptBox::LoweredVectorSubscript;
42   using LoweredTriplet = Fortran::lower::VectorSubscriptBox::LoweredTriplet;
43   using LoweredSubscript = Fortran::lower::VectorSubscriptBox::LoweredSubscript;
44   using MaybeSubstring = Fortran::lower::VectorSubscriptBox::MaybeSubstring;
45 
46   /// genDesignator unwraps a Designator<T> and calls `gen` on what the
47   /// designator actually contains.
48   template <typename A>
genDesignator(const A &)49   mlir::Type genDesignator(const A &) {
50     fir::emitFatalError(loc, "expr must contain a designator");
51   }
52   template <typename T>
genDesignator(const Fortran::evaluate::Expr<T> & expr)53   mlir::Type genDesignator(const Fortran::evaluate::Expr<T> &expr) {
54     using ExprVariant = decltype(Fortran::evaluate::Expr<T>::u);
55     using Designator = Fortran::evaluate::Designator<T>;
56     if constexpr (Fortran::common::HasMember<Designator, ExprVariant>) {
57       const auto &designator = std::get<Designator>(expr.u);
58       return Fortran::common::visit([&](const auto &x) { return gen(x); },
59                                     designator.u);
60     } else {
61       return Fortran::common::visit(
62           [&](const auto &x) { return genDesignator(x); }, expr.u);
63     }
64   }
65 
66   // The gen(X) methods visit X to lower its base and subscripts and return the
67   // type of X elements.
68 
gen(const Fortran::evaluate::DataRef & dataRef)69   mlir::Type gen(const Fortran::evaluate::DataRef &dataRef) {
70     return Fortran::common::visit(
71         [&](const auto &ref) -> mlir::Type { return gen(ref); }, dataRef.u);
72   }
73 
gen(const Fortran::evaluate::SymbolRef & symRef)74   mlir::Type gen(const Fortran::evaluate::SymbolRef &symRef) {
75     // Never visited because expr lowering is used to lowered the ranked
76     // ArrayRef.
77     fir::emitFatalError(
78         loc, "expected at least one ArrayRef with vector susbcripts");
79   }
80 
gen(const Fortran::evaluate::Substring & substring)81   mlir::Type gen(const Fortran::evaluate::Substring &substring) {
82     // StaticDataObject::Pointer bases are constants and cannot be
83     // subscripted, so the base must be a DataRef here.
84     mlir::Type baseElementType =
85         gen(std::get<Fortran::evaluate::DataRef>(substring.parent()));
86     fir::FirOpBuilder &builder = converter.getFirOpBuilder();
87     mlir::Type idxTy = builder.getIndexType();
88     mlir::Value lb = genScalarValue(substring.lower());
89     substringBounds.emplace_back(builder.createConvert(loc, idxTy, lb));
90     if (const auto &ubExpr = substring.upper()) {
91       mlir::Value ub = genScalarValue(*ubExpr);
92       substringBounds.emplace_back(builder.createConvert(loc, idxTy, ub));
93     }
94     return baseElementType;
95   }
96 
gen(const Fortran::evaluate::ComplexPart & complexPart)97   mlir::Type gen(const Fortran::evaluate::ComplexPart &complexPart) {
98     auto complexType = gen(complexPart.complex());
99     fir::FirOpBuilder &builder = converter.getFirOpBuilder();
100     mlir::Type i32Ty = builder.getI32Type(); // llvm's GEP requires i32
101     mlir::Value offset = builder.createIntegerConstant(
102         loc, i32Ty,
103         complexPart.part() == Fortran::evaluate::ComplexPart::Part::RE ? 0 : 1);
104     componentPath.emplace_back(offset);
105     return fir::factory::Complex{builder, loc}.getComplexPartType(complexType);
106   }
107 
gen(const Fortran::evaluate::Component & component)108   mlir::Type gen(const Fortran::evaluate::Component &component) {
109     auto recTy = mlir::cast<fir::RecordType>(gen(component.base()));
110     const Fortran::semantics::Symbol &componentSymbol =
111         component.GetLastSymbol();
112     // Parent components will not be found here, they are not part
113     // of the FIR type and cannot be used in the path yet.
114     if (componentSymbol.test(Fortran::semantics::Symbol::Flag::ParentComp))
115       TODO(loc, "reference to parent component");
116     mlir::Type fldTy = fir::FieldType::get(&converter.getMLIRContext());
117     llvm::StringRef componentName = toStringRef(componentSymbol.name());
118     // Parameters threading in field_index is not yet very clear. We only
119     // have the ones of the ranked array ref at hand, but it looks like
120     // the fir.field_index expects the one of the direct base.
121     if (recTy.getNumLenParams() != 0)
122       TODO(loc, "threading length parameters in field index op");
123     fir::FirOpBuilder &builder = converter.getFirOpBuilder();
124     componentPath.emplace_back(builder.create<fir::FieldIndexOp>(
125         loc, fldTy, componentName, recTy, /*typeParams*/ std::nullopt));
126     return fir::unwrapSequenceType(recTy.getType(componentName));
127   }
128 
gen(const Fortran::evaluate::ArrayRef & arrayRef)129   mlir::Type gen(const Fortran::evaluate::ArrayRef &arrayRef) {
130     auto isTripletOrVector =
131         [](const Fortran::evaluate::Subscript &subscript) -> bool {
132       return Fortran::common::visit(
133           Fortran::common::visitors{
134               [](const Fortran::evaluate::IndirectSubscriptIntegerExpr &expr) {
135                 return expr.value().Rank() != 0;
136               },
137               [&](const Fortran::evaluate::Triplet &) { return true; }},
138           subscript.u);
139     };
140     if (llvm::any_of(arrayRef.subscript(), isTripletOrVector))
141       return genRankedArrayRefSubscriptAndBase(arrayRef);
142 
143     // This is a scalar ArrayRef (only scalar indexes), collect the indexes and
144     // visit the base that must contain another arrayRef with the vector
145     // subscript.
146     mlir::Type elementType = gen(namedEntityToDataRef(arrayRef.base()));
147     for (const Fortran::evaluate::Subscript &subscript : arrayRef.subscript()) {
148       const auto &expr =
149           std::get<Fortran::evaluate::IndirectSubscriptIntegerExpr>(
150               subscript.u);
151       componentPath.emplace_back(genScalarValue(expr.value()));
152     }
153     return elementType;
154   }
155 
156   /// Lower the subscripts and base of the ArrayRef that is an array (there must
157   /// be one since there is a vector subscript, and there can only be one
158   /// according to C925).
genRankedArrayRefSubscriptAndBase(const Fortran::evaluate::ArrayRef & arrayRef)159   mlir::Type genRankedArrayRefSubscriptAndBase(
160       const Fortran::evaluate::ArrayRef &arrayRef) {
161     // Lower the save the base
162     Fortran::lower::SomeExpr baseExpr = namedEntityToExpr(arrayRef.base());
163     loweredBase = converter.genExprAddr(baseExpr, stmtCtx);
164     // Lower and save the subscripts
165     fir::FirOpBuilder &builder = converter.getFirOpBuilder();
166     mlir::Type idxTy = builder.getIndexType();
167     mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
168     for (const auto &subscript : llvm::enumerate(arrayRef.subscript())) {
169       Fortran::common::visit(
170           Fortran::common::visitors{
171               [&](const Fortran::evaluate::IndirectSubscriptIntegerExpr &expr) {
172                 if (expr.value().Rank() == 0) {
173                   // Simple scalar subscript
174                   loweredSubscripts.emplace_back(genScalarValue(expr.value()));
175                 } else {
176                   // Vector subscript.
177                   // Remove conversion if any to avoid temp creation that may
178                   // have been added by the front-end to avoid the creation of a
179                   // temp array value.
180                   auto vector = converter.genExprAddr(
181                       ignoreEvConvert(expr.value()), stmtCtx);
182                   mlir::Value size =
183                       fir::factory::readExtent(builder, loc, vector, /*dim=*/0);
184                   size = builder.createConvert(loc, idxTy, size);
185                   loweredSubscripts.emplace_back(
186                       LoweredVectorSubscript{std::move(vector), size});
187                 }
188               },
189               [&](const Fortran::evaluate::Triplet &triplet) {
190                 mlir::Value lb, ub;
191                 if (const auto &lbExpr = triplet.lower())
192                   lb = genScalarValue(*lbExpr);
193                 else
194                   lb = fir::factory::readLowerBound(builder, loc, loweredBase,
195                                                     subscript.index(), one);
196                 if (const auto &ubExpr = triplet.upper())
197                   ub = genScalarValue(*ubExpr);
198                 else
199                   ub = fir::factory::readExtent(builder, loc, loweredBase,
200                                                 subscript.index());
201                 lb = builder.createConvert(loc, idxTy, lb);
202                 ub = builder.createConvert(loc, idxTy, ub);
203                 mlir::Value stride = genScalarValue(triplet.stride());
204                 stride = builder.createConvert(loc, idxTy, stride);
205                 loweredSubscripts.emplace_back(LoweredTriplet{lb, ub, stride});
206               },
207           },
208           subscript.value().u);
209     }
210     return fir::unwrapSequenceType(
211         fir::unwrapPassByRefType(fir::getBase(loweredBase).getType()));
212   }
213 
gen(const Fortran::evaluate::CoarrayRef &)214   mlir::Type gen(const Fortran::evaluate::CoarrayRef &) {
215     // Is this possible/legal ?
216     TODO(loc, "coarray: reference to coarray object with vector subscript in "
217               "IO input");
218   }
219 
220   template <typename A>
genScalarValue(const A & expr)221   mlir::Value genScalarValue(const A &expr) {
222     return fir::getBase(converter.genExprValue(toEvExpr(expr), stmtCtx));
223   }
224 
225   Fortran::evaluate::DataRef
namedEntityToDataRef(const Fortran::evaluate::NamedEntity & namedEntity)226   namedEntityToDataRef(const Fortran::evaluate::NamedEntity &namedEntity) {
227     if (namedEntity.IsSymbol())
228       return Fortran::evaluate::DataRef{namedEntity.GetFirstSymbol()};
229     return Fortran::evaluate::DataRef{namedEntity.GetComponent()};
230   }
231 
232   Fortran::lower::SomeExpr
namedEntityToExpr(const Fortran::evaluate::NamedEntity & namedEntity)233   namedEntityToExpr(const Fortran::evaluate::NamedEntity &namedEntity) {
234     return Fortran::evaluate::AsGenericExpr(namedEntityToDataRef(namedEntity))
235         .value();
236   }
237 
238   Fortran::lower::AbstractConverter &converter;
239   Fortran::lower::StatementContext &stmtCtx;
240   mlir::Location loc;
241   /// Elements of VectorSubscriptBox being built.
242   fir::ExtendedValue loweredBase;
243   llvm::SmallVector<LoweredSubscript, 16> loweredSubscripts;
244   llvm::SmallVector<mlir::Value> componentPath;
245   MaybeSubstring substringBounds;
246   mlir::Type elementType;
247 };
248 } // namespace
249 
genVectorSubscriptBox(mlir::Location loc,Fortran::lower::AbstractConverter & converter,Fortran::lower::StatementContext & stmtCtx,const Fortran::lower::SomeExpr & expr)250 Fortran::lower::VectorSubscriptBox Fortran::lower::genVectorSubscriptBox(
251     mlir::Location loc, Fortran::lower::AbstractConverter &converter,
252     Fortran::lower::StatementContext &stmtCtx,
253     const Fortran::lower::SomeExpr &expr) {
254   return VectorSubscriptBoxBuilder(loc, converter, stmtCtx).gen(expr);
255 }
256 
257 template <typename LoopType, typename Generator>
loopOverElementsBase(fir::FirOpBuilder & builder,mlir::Location loc,const Generator & elementalGenerator,mlir::Value initialCondition)258 mlir::Value Fortran::lower::VectorSubscriptBox::loopOverElementsBase(
259     fir::FirOpBuilder &builder, mlir::Location loc,
260     const Generator &elementalGenerator,
261     [[maybe_unused]] mlir::Value initialCondition) {
262   mlir::Value shape = builder.createShape(loc, loweredBase);
263   mlir::Value slice = createSlice(builder, loc);
264 
265   // Create loop nest for triplets and vector subscripts in column
266   // major order.
267   llvm::SmallVector<mlir::Value> inductionVariables;
268   LoopType outerLoop;
269   for (auto [lb, ub, step] : genLoopBounds(builder, loc)) {
270     LoopType loop;
271     if constexpr (std::is_same_v<LoopType, fir::IterWhileOp>) {
272       loop =
273           builder.create<fir::IterWhileOp>(loc, lb, ub, step, initialCondition);
274       initialCondition = loop.getIterateVar();
275       if (!outerLoop)
276         outerLoop = loop;
277       else
278         builder.create<fir::ResultOp>(loc, loop.getResult(0));
279     } else {
280       loop =
281           builder.create<fir::DoLoopOp>(loc, lb, ub, step, /*unordered=*/false);
282       if (!outerLoop)
283         outerLoop = loop;
284     }
285     builder.setInsertionPointToStart(loop.getBody());
286     inductionVariables.push_back(loop.getInductionVar());
287   }
288   assert(outerLoop && !inductionVariables.empty() &&
289          "at least one loop should be created");
290 
291   fir::ExtendedValue elem =
292       getElementAt(builder, loc, shape, slice, inductionVariables);
293 
294   if constexpr (std::is_same_v<LoopType, fir::IterWhileOp>) {
295     auto res = elementalGenerator(elem);
296     builder.create<fir::ResultOp>(loc, res);
297     builder.setInsertionPointAfter(outerLoop);
298     return outerLoop.getResult(0);
299   } else {
300     elementalGenerator(elem);
301     builder.setInsertionPointAfter(outerLoop);
302     return {};
303   }
304 }
305 
loopOverElements(fir::FirOpBuilder & builder,mlir::Location loc,const ElementalGenerator & elementalGenerator)306 void Fortran::lower::VectorSubscriptBox::loopOverElements(
307     fir::FirOpBuilder &builder, mlir::Location loc,
308     const ElementalGenerator &elementalGenerator) {
309   mlir::Value initialCondition;
310   loopOverElementsBase<fir::DoLoopOp, ElementalGenerator>(
311       builder, loc, elementalGenerator, initialCondition);
312 }
313 
loopOverElementsWhile(fir::FirOpBuilder & builder,mlir::Location loc,const ElementalGeneratorWithBoolReturn & elementalGenerator,mlir::Value initialCondition)314 mlir::Value Fortran::lower::VectorSubscriptBox::loopOverElementsWhile(
315     fir::FirOpBuilder &builder, mlir::Location loc,
316     const ElementalGeneratorWithBoolReturn &elementalGenerator,
317     mlir::Value initialCondition) {
318   return loopOverElementsBase<fir::IterWhileOp,
319                               ElementalGeneratorWithBoolReturn>(
320       builder, loc, elementalGenerator, initialCondition);
321 }
322 
323 mlir::Value
createSlice(fir::FirOpBuilder & builder,mlir::Location loc)324 Fortran::lower::VectorSubscriptBox::createSlice(fir::FirOpBuilder &builder,
325                                                 mlir::Location loc) {
326   mlir::Type idxTy = builder.getIndexType();
327   llvm::SmallVector<mlir::Value> triples;
328   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
329   auto undef = builder.create<fir::UndefOp>(loc, idxTy);
330   for (const LoweredSubscript &subscript : loweredSubscripts)
331     Fortran::common::visit(Fortran::common::visitors{
332                                [&](const LoweredTriplet &triplet) {
333                                  triples.emplace_back(triplet.lb);
334                                  triples.emplace_back(triplet.ub);
335                                  triples.emplace_back(triplet.stride);
336                                },
337                                [&](const LoweredVectorSubscript &vector) {
338                                  triples.emplace_back(one);
339                                  triples.emplace_back(vector.size);
340                                  triples.emplace_back(one);
341                                },
342                                [&](const mlir::Value &i) {
343                                  triples.emplace_back(i);
344                                  triples.emplace_back(undef);
345                                  triples.emplace_back(undef);
346                                },
347                            },
348                            subscript);
349   return builder.create<fir::SliceOp>(loc, triples, componentPath);
350 }
351 
352 llvm::SmallVector<std::tuple<mlir::Value, mlir::Value, mlir::Value>>
genLoopBounds(fir::FirOpBuilder & builder,mlir::Location loc)353 Fortran::lower::VectorSubscriptBox::genLoopBounds(fir::FirOpBuilder &builder,
354                                                   mlir::Location loc) {
355   mlir::Type idxTy = builder.getIndexType();
356   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
357   mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
358   llvm::SmallVector<std::tuple<mlir::Value, mlir::Value, mlir::Value>> bounds;
359   size_t dimension = loweredSubscripts.size();
360   for (const LoweredSubscript &subscript : llvm::reverse(loweredSubscripts)) {
361     --dimension;
362     if (std::holds_alternative<mlir::Value>(subscript))
363       continue;
364     mlir::Value lb, ub, step;
365     if (const auto *triplet = std::get_if<LoweredTriplet>(&subscript)) {
366       mlir::Value extent = builder.genExtentFromTriplet(
367           loc, triplet->lb, triplet->ub, triplet->stride, idxTy);
368       mlir::Value baseLb = fir::factory::readLowerBound(
369           builder, loc, loweredBase, dimension, one);
370       baseLb = builder.createConvert(loc, idxTy, baseLb);
371       lb = baseLb;
372       ub = builder.create<mlir::arith::SubIOp>(loc, idxTy, extent, one);
373       ub = builder.create<mlir::arith::AddIOp>(loc, idxTy, ub, baseLb);
374       step = one;
375     } else {
376       const auto &vector = std::get<LoweredVectorSubscript>(subscript);
377       lb = zero;
378       ub = builder.create<mlir::arith::SubIOp>(loc, idxTy, vector.size, one);
379       step = one;
380     }
381     bounds.emplace_back(lb, ub, step);
382   }
383   return bounds;
384 }
385 
getElementAt(fir::FirOpBuilder & builder,mlir::Location loc,mlir::Value shape,mlir::Value slice,mlir::ValueRange inductionVariables)386 fir::ExtendedValue Fortran::lower::VectorSubscriptBox::getElementAt(
387     fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value shape,
388     mlir::Value slice, mlir::ValueRange inductionVariables) {
389   /// Generate the indexes for the array_coor inside the loops.
390   mlir::Type idxTy = builder.getIndexType();
391   llvm::SmallVector<mlir::Value> indexes;
392   size_t inductionIdx = inductionVariables.size() - 1;
393   for (const LoweredSubscript &subscript : loweredSubscripts)
394     Fortran::common::visit(
395         Fortran::common::visitors{
396             [&](const LoweredTriplet &triplet) {
397               indexes.emplace_back(inductionVariables[inductionIdx--]);
398             },
399             [&](const LoweredVectorSubscript &vector) {
400               mlir::Value vecIndex = inductionVariables[inductionIdx--];
401               mlir::Value vecBase = fir::getBase(vector.vector);
402               mlir::Type vecEleTy = fir::unwrapSequenceType(
403                   fir::unwrapPassByRefType(vecBase.getType()));
404               mlir::Type refTy = builder.getRefType(vecEleTy);
405               auto vecEltRef = builder.create<fir::CoordinateOp>(
406                   loc, refTy, vecBase, vecIndex);
407               auto vecElt =
408                   builder.create<fir::LoadOp>(loc, vecEleTy, vecEltRef);
409               indexes.emplace_back(builder.createConvert(loc, idxTy, vecElt));
410             },
411             [&](const mlir::Value &i) {
412               indexes.emplace_back(builder.createConvert(loc, idxTy, i));
413             },
414         },
415         subscript);
416   mlir::Type refTy = builder.getRefType(getElementType());
417   auto elementAddr = builder.create<fir::ArrayCoorOp>(
418       loc, refTy, fir::getBase(loweredBase), shape, slice, indexes,
419       fir::getTypeParams(loweredBase));
420   fir::ExtendedValue element = fir::factory::arraySectionElementToExtendedValue(
421       builder, loc, loweredBase, elementAddr, slice);
422   if (!substringBounds.empty()) {
423     const fir::CharBoxValue *charBox = element.getCharBox();
424     assert(charBox && "substring requires CharBox base");
425     fir::factory::CharacterExprHelper helper{builder, loc};
426     return helper.createSubstring(*charBox, substringBounds);
427   }
428   return element;
429 }
430