1 //===-- VectorSubscripts.cpp -- Vector subscripts tools -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "flang/Lower/VectorSubscripts.h"
14 #include "flang/Lower/AbstractConverter.h"
15 #include "flang/Lower/Support/Utils.h"
16 #include "flang/Optimizer/Builder/Character.h"
17 #include "flang/Optimizer/Builder/Complex.h"
18 #include "flang/Optimizer/Builder/FIRBuilder.h"
19 #include "flang/Optimizer/Builder/Todo.h"
20 #include "flang/Semantics/expression.h"
21
22 namespace {
23 /// Helper class to lower a designator containing vector subscripts into a
24 /// lowered representation that can be worked with.
25 class VectorSubscriptBoxBuilder {
26 public:
VectorSubscriptBoxBuilder(mlir::Location loc,Fortran::lower::AbstractConverter & converter,Fortran::lower::StatementContext & stmtCtx)27 VectorSubscriptBoxBuilder(mlir::Location loc,
28 Fortran::lower::AbstractConverter &converter,
29 Fortran::lower::StatementContext &stmtCtx)
30 : converter{converter}, stmtCtx{stmtCtx}, loc{loc} {}
31
gen(const Fortran::lower::SomeExpr & expr)32 Fortran::lower::VectorSubscriptBox gen(const Fortran::lower::SomeExpr &expr) {
33 elementType = genDesignator(expr);
34 return Fortran::lower::VectorSubscriptBox(
35 std::move(loweredBase), std::move(loweredSubscripts),
36 std::move(componentPath), substringBounds, elementType);
37 }
38
39 private:
40 using LoweredVectorSubscript =
41 Fortran::lower::VectorSubscriptBox::LoweredVectorSubscript;
42 using LoweredTriplet = Fortran::lower::VectorSubscriptBox::LoweredTriplet;
43 using LoweredSubscript = Fortran::lower::VectorSubscriptBox::LoweredSubscript;
44 using MaybeSubstring = Fortran::lower::VectorSubscriptBox::MaybeSubstring;
45
46 /// genDesignator unwraps a Designator<T> and calls `gen` on what the
47 /// designator actually contains.
48 template <typename A>
genDesignator(const A &)49 mlir::Type genDesignator(const A &) {
50 fir::emitFatalError(loc, "expr must contain a designator");
51 }
52 template <typename T>
genDesignator(const Fortran::evaluate::Expr<T> & expr)53 mlir::Type genDesignator(const Fortran::evaluate::Expr<T> &expr) {
54 using ExprVariant = decltype(Fortran::evaluate::Expr<T>::u);
55 using Designator = Fortran::evaluate::Designator<T>;
56 if constexpr (Fortran::common::HasMember<Designator, ExprVariant>) {
57 const auto &designator = std::get<Designator>(expr.u);
58 return Fortran::common::visit([&](const auto &x) { return gen(x); },
59 designator.u);
60 } else {
61 return Fortran::common::visit(
62 [&](const auto &x) { return genDesignator(x); }, expr.u);
63 }
64 }
65
66 // The gen(X) methods visit X to lower its base and subscripts and return the
67 // type of X elements.
68
gen(const Fortran::evaluate::DataRef & dataRef)69 mlir::Type gen(const Fortran::evaluate::DataRef &dataRef) {
70 return Fortran::common::visit(
71 [&](const auto &ref) -> mlir::Type { return gen(ref); }, dataRef.u);
72 }
73
gen(const Fortran::evaluate::SymbolRef & symRef)74 mlir::Type gen(const Fortran::evaluate::SymbolRef &symRef) {
75 // Never visited because expr lowering is used to lowered the ranked
76 // ArrayRef.
77 fir::emitFatalError(
78 loc, "expected at least one ArrayRef with vector susbcripts");
79 }
80
gen(const Fortran::evaluate::Substring & substring)81 mlir::Type gen(const Fortran::evaluate::Substring &substring) {
82 // StaticDataObject::Pointer bases are constants and cannot be
83 // subscripted, so the base must be a DataRef here.
84 mlir::Type baseElementType =
85 gen(std::get<Fortran::evaluate::DataRef>(substring.parent()));
86 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
87 mlir::Type idxTy = builder.getIndexType();
88 mlir::Value lb = genScalarValue(substring.lower());
89 substringBounds.emplace_back(builder.createConvert(loc, idxTy, lb));
90 if (const auto &ubExpr = substring.upper()) {
91 mlir::Value ub = genScalarValue(*ubExpr);
92 substringBounds.emplace_back(builder.createConvert(loc, idxTy, ub));
93 }
94 return baseElementType;
95 }
96
gen(const Fortran::evaluate::ComplexPart & complexPart)97 mlir::Type gen(const Fortran::evaluate::ComplexPart &complexPart) {
98 auto complexType = gen(complexPart.complex());
99 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
100 mlir::Type i32Ty = builder.getI32Type(); // llvm's GEP requires i32
101 mlir::Value offset = builder.createIntegerConstant(
102 loc, i32Ty,
103 complexPart.part() == Fortran::evaluate::ComplexPart::Part::RE ? 0 : 1);
104 componentPath.emplace_back(offset);
105 return fir::factory::Complex{builder, loc}.getComplexPartType(complexType);
106 }
107
gen(const Fortran::evaluate::Component & component)108 mlir::Type gen(const Fortran::evaluate::Component &component) {
109 auto recTy = mlir::cast<fir::RecordType>(gen(component.base()));
110 const Fortran::semantics::Symbol &componentSymbol =
111 component.GetLastSymbol();
112 // Parent components will not be found here, they are not part
113 // of the FIR type and cannot be used in the path yet.
114 if (componentSymbol.test(Fortran::semantics::Symbol::Flag::ParentComp))
115 TODO(loc, "reference to parent component");
116 mlir::Type fldTy = fir::FieldType::get(&converter.getMLIRContext());
117 llvm::StringRef componentName = toStringRef(componentSymbol.name());
118 // Parameters threading in field_index is not yet very clear. We only
119 // have the ones of the ranked array ref at hand, but it looks like
120 // the fir.field_index expects the one of the direct base.
121 if (recTy.getNumLenParams() != 0)
122 TODO(loc, "threading length parameters in field index op");
123 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
124 componentPath.emplace_back(builder.create<fir::FieldIndexOp>(
125 loc, fldTy, componentName, recTy, /*typeParams*/ std::nullopt));
126 return fir::unwrapSequenceType(recTy.getType(componentName));
127 }
128
gen(const Fortran::evaluate::ArrayRef & arrayRef)129 mlir::Type gen(const Fortran::evaluate::ArrayRef &arrayRef) {
130 auto isTripletOrVector =
131 [](const Fortran::evaluate::Subscript &subscript) -> bool {
132 return Fortran::common::visit(
133 Fortran::common::visitors{
134 [](const Fortran::evaluate::IndirectSubscriptIntegerExpr &expr) {
135 return expr.value().Rank() != 0;
136 },
137 [&](const Fortran::evaluate::Triplet &) { return true; }},
138 subscript.u);
139 };
140 if (llvm::any_of(arrayRef.subscript(), isTripletOrVector))
141 return genRankedArrayRefSubscriptAndBase(arrayRef);
142
143 // This is a scalar ArrayRef (only scalar indexes), collect the indexes and
144 // visit the base that must contain another arrayRef with the vector
145 // subscript.
146 mlir::Type elementType = gen(namedEntityToDataRef(arrayRef.base()));
147 for (const Fortran::evaluate::Subscript &subscript : arrayRef.subscript()) {
148 const auto &expr =
149 std::get<Fortran::evaluate::IndirectSubscriptIntegerExpr>(
150 subscript.u);
151 componentPath.emplace_back(genScalarValue(expr.value()));
152 }
153 return elementType;
154 }
155
156 /// Lower the subscripts and base of the ArrayRef that is an array (there must
157 /// be one since there is a vector subscript, and there can only be one
158 /// according to C925).
genRankedArrayRefSubscriptAndBase(const Fortran::evaluate::ArrayRef & arrayRef)159 mlir::Type genRankedArrayRefSubscriptAndBase(
160 const Fortran::evaluate::ArrayRef &arrayRef) {
161 // Lower the save the base
162 Fortran::lower::SomeExpr baseExpr = namedEntityToExpr(arrayRef.base());
163 loweredBase = converter.genExprAddr(baseExpr, stmtCtx);
164 // Lower and save the subscripts
165 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
166 mlir::Type idxTy = builder.getIndexType();
167 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
168 for (const auto &subscript : llvm::enumerate(arrayRef.subscript())) {
169 Fortran::common::visit(
170 Fortran::common::visitors{
171 [&](const Fortran::evaluate::IndirectSubscriptIntegerExpr &expr) {
172 if (expr.value().Rank() == 0) {
173 // Simple scalar subscript
174 loweredSubscripts.emplace_back(genScalarValue(expr.value()));
175 } else {
176 // Vector subscript.
177 // Remove conversion if any to avoid temp creation that may
178 // have been added by the front-end to avoid the creation of a
179 // temp array value.
180 auto vector = converter.genExprAddr(
181 ignoreEvConvert(expr.value()), stmtCtx);
182 mlir::Value size =
183 fir::factory::readExtent(builder, loc, vector, /*dim=*/0);
184 size = builder.createConvert(loc, idxTy, size);
185 loweredSubscripts.emplace_back(
186 LoweredVectorSubscript{std::move(vector), size});
187 }
188 },
189 [&](const Fortran::evaluate::Triplet &triplet) {
190 mlir::Value lb, ub;
191 if (const auto &lbExpr = triplet.lower())
192 lb = genScalarValue(*lbExpr);
193 else
194 lb = fir::factory::readLowerBound(builder, loc, loweredBase,
195 subscript.index(), one);
196 if (const auto &ubExpr = triplet.upper())
197 ub = genScalarValue(*ubExpr);
198 else
199 ub = fir::factory::readExtent(builder, loc, loweredBase,
200 subscript.index());
201 lb = builder.createConvert(loc, idxTy, lb);
202 ub = builder.createConvert(loc, idxTy, ub);
203 mlir::Value stride = genScalarValue(triplet.stride());
204 stride = builder.createConvert(loc, idxTy, stride);
205 loweredSubscripts.emplace_back(LoweredTriplet{lb, ub, stride});
206 },
207 },
208 subscript.value().u);
209 }
210 return fir::unwrapSequenceType(
211 fir::unwrapPassByRefType(fir::getBase(loweredBase).getType()));
212 }
213
gen(const Fortran::evaluate::CoarrayRef &)214 mlir::Type gen(const Fortran::evaluate::CoarrayRef &) {
215 // Is this possible/legal ?
216 TODO(loc, "coarray: reference to coarray object with vector subscript in "
217 "IO input");
218 }
219
220 template <typename A>
genScalarValue(const A & expr)221 mlir::Value genScalarValue(const A &expr) {
222 return fir::getBase(converter.genExprValue(toEvExpr(expr), stmtCtx));
223 }
224
225 Fortran::evaluate::DataRef
namedEntityToDataRef(const Fortran::evaluate::NamedEntity & namedEntity)226 namedEntityToDataRef(const Fortran::evaluate::NamedEntity &namedEntity) {
227 if (namedEntity.IsSymbol())
228 return Fortran::evaluate::DataRef{namedEntity.GetFirstSymbol()};
229 return Fortran::evaluate::DataRef{namedEntity.GetComponent()};
230 }
231
232 Fortran::lower::SomeExpr
namedEntityToExpr(const Fortran::evaluate::NamedEntity & namedEntity)233 namedEntityToExpr(const Fortran::evaluate::NamedEntity &namedEntity) {
234 return Fortran::evaluate::AsGenericExpr(namedEntityToDataRef(namedEntity))
235 .value();
236 }
237
238 Fortran::lower::AbstractConverter &converter;
239 Fortran::lower::StatementContext &stmtCtx;
240 mlir::Location loc;
241 /// Elements of VectorSubscriptBox being built.
242 fir::ExtendedValue loweredBase;
243 llvm::SmallVector<LoweredSubscript, 16> loweredSubscripts;
244 llvm::SmallVector<mlir::Value> componentPath;
245 MaybeSubstring substringBounds;
246 mlir::Type elementType;
247 };
248 } // namespace
249
genVectorSubscriptBox(mlir::Location loc,Fortran::lower::AbstractConverter & converter,Fortran::lower::StatementContext & stmtCtx,const Fortran::lower::SomeExpr & expr)250 Fortran::lower::VectorSubscriptBox Fortran::lower::genVectorSubscriptBox(
251 mlir::Location loc, Fortran::lower::AbstractConverter &converter,
252 Fortran::lower::StatementContext &stmtCtx,
253 const Fortran::lower::SomeExpr &expr) {
254 return VectorSubscriptBoxBuilder(loc, converter, stmtCtx).gen(expr);
255 }
256
/// Generate a loop nest over the elements designated by the vector subscripted
/// designator and call \p elementalGenerator on each element address.
///
/// \tparam LoopType fir::DoLoopOp or fir::IterWhileOp.
/// \p initialCondition is only used with fir::IterWhileOp. It is the initial
/// iterate condition of the outermost loop, and each nested loop starts from
/// the iterate variable of its parent. In that mode, the generator returns the
/// condition to propagate, and the function returns result 0 of the outermost
/// loop. With fir::DoLoopOp, a null value is returned.
/// On exit, the builder insertion point is right after the outermost loop.
template <typename LoopType, typename Generator>
mlir::Value Fortran::lower::VectorSubscriptBox::loopOverElementsBase(
    fir::FirOpBuilder &builder, mlir::Location loc,
    const Generator &elementalGenerator,
    [[maybe_unused]] mlir::Value initialCondition) {
  mlir::Value shape = builder.createShape(loc, loweredBase);
  mlir::Value slice = createSlice(builder, loc);

  // Create loop nest for triplets and vector subscripts in column
  // major order. genLoopBounds yields bounds starting from the last
  // dimension, so the first loop created is the outermost one.
  llvm::SmallVector<mlir::Value> inductionVariables;
  LoopType outerLoop;
  for (auto [lb, ub, step] : genLoopBounds(builder, loc)) {
    LoopType loop;
    if constexpr (std::is_same_v<LoopType, fir::IterWhileOp>) {
      loop =
          builder.create<fir::IterWhileOp>(loc, lb, ub, step, initialCondition);
      // The nested loop (if any) starts from this loop's iterate variable.
      initialCondition = loop.getIterateVar();
      if (!outerLoop)
        outerLoop = loop;
      else
        // The insertion point is inside the parent loop body: make the parent
        // yield the result of this nested loop.
        builder.create<fir::ResultOp>(loc, loop.getResult(0));
    } else {
      loop =
          builder.create<fir::DoLoopOp>(loc, lb, ub, step, /*unordered=*/false);
      if (!outerLoop)
        outerLoop = loop;
    }
    // Nest the next loop (or the element code) inside this loop.
    builder.setInsertionPointToStart(loop.getBody());
    inductionVariables.push_back(loop.getInductionVar());
  }
  assert(outerLoop && !inductionVariables.empty() &&
         "at least one loop should be created");

  fir::ExtendedValue elem =
      getElementAt(builder, loc, shape, slice, inductionVariables);

  if constexpr (std::is_same_v<LoopType, fir::IterWhileOp>) {
    // The innermost loop yields the condition computed by the generator.
    auto res = elementalGenerator(elem);
    builder.create<fir::ResultOp>(loc, res);
    builder.setInsertionPointAfter(outerLoop);
    return outerLoop.getResult(0);
  } else {
    elementalGenerator(elem);
    builder.setInsertionPointAfter(outerLoop);
    return {};
  }
}
305
loopOverElements(fir::FirOpBuilder & builder,mlir::Location loc,const ElementalGenerator & elementalGenerator)306 void Fortran::lower::VectorSubscriptBox::loopOverElements(
307 fir::FirOpBuilder &builder, mlir::Location loc,
308 const ElementalGenerator &elementalGenerator) {
309 mlir::Value initialCondition;
310 loopOverElementsBase<fir::DoLoopOp, ElementalGenerator>(
311 builder, loc, elementalGenerator, initialCondition);
312 }
313
loopOverElementsWhile(fir::FirOpBuilder & builder,mlir::Location loc,const ElementalGeneratorWithBoolReturn & elementalGenerator,mlir::Value initialCondition)314 mlir::Value Fortran::lower::VectorSubscriptBox::loopOverElementsWhile(
315 fir::FirOpBuilder &builder, mlir::Location loc,
316 const ElementalGeneratorWithBoolReturn &elementalGenerator,
317 mlir::Value initialCondition) {
318 return loopOverElementsBase<fir::IterWhileOp,
319 ElementalGeneratorWithBoolReturn>(
320 builder, loc, elementalGenerator, initialCondition);
321 }
322
323 mlir::Value
createSlice(fir::FirOpBuilder & builder,mlir::Location loc)324 Fortran::lower::VectorSubscriptBox::createSlice(fir::FirOpBuilder &builder,
325 mlir::Location loc) {
326 mlir::Type idxTy = builder.getIndexType();
327 llvm::SmallVector<mlir::Value> triples;
328 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
329 auto undef = builder.create<fir::UndefOp>(loc, idxTy);
330 for (const LoweredSubscript &subscript : loweredSubscripts)
331 Fortran::common::visit(Fortran::common::visitors{
332 [&](const LoweredTriplet &triplet) {
333 triples.emplace_back(triplet.lb);
334 triples.emplace_back(triplet.ub);
335 triples.emplace_back(triplet.stride);
336 },
337 [&](const LoweredVectorSubscript &vector) {
338 triples.emplace_back(one);
339 triples.emplace_back(vector.size);
340 triples.emplace_back(one);
341 },
342 [&](const mlir::Value &i) {
343 triples.emplace_back(i);
344 triples.emplace_back(undef);
345 triples.emplace_back(undef);
346 },
347 },
348 subscript);
349 return builder.create<fir::SliceOp>(loc, triples, componentPath);
350 }
351
352 llvm::SmallVector<std::tuple<mlir::Value, mlir::Value, mlir::Value>>
genLoopBounds(fir::FirOpBuilder & builder,mlir::Location loc)353 Fortran::lower::VectorSubscriptBox::genLoopBounds(fir::FirOpBuilder &builder,
354 mlir::Location loc) {
355 mlir::Type idxTy = builder.getIndexType();
356 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
357 mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
358 llvm::SmallVector<std::tuple<mlir::Value, mlir::Value, mlir::Value>> bounds;
359 size_t dimension = loweredSubscripts.size();
360 for (const LoweredSubscript &subscript : llvm::reverse(loweredSubscripts)) {
361 --dimension;
362 if (std::holds_alternative<mlir::Value>(subscript))
363 continue;
364 mlir::Value lb, ub, step;
365 if (const auto *triplet = std::get_if<LoweredTriplet>(&subscript)) {
366 mlir::Value extent = builder.genExtentFromTriplet(
367 loc, triplet->lb, triplet->ub, triplet->stride, idxTy);
368 mlir::Value baseLb = fir::factory::readLowerBound(
369 builder, loc, loweredBase, dimension, one);
370 baseLb = builder.createConvert(loc, idxTy, baseLb);
371 lb = baseLb;
372 ub = builder.create<mlir::arith::SubIOp>(loc, idxTy, extent, one);
373 ub = builder.create<mlir::arith::AddIOp>(loc, idxTy, ub, baseLb);
374 step = one;
375 } else {
376 const auto &vector = std::get<LoweredVectorSubscript>(subscript);
377 lb = zero;
378 ub = builder.create<mlir::arith::SubIOp>(loc, idxTy, vector.size, one);
379 step = one;
380 }
381 bounds.emplace_back(lb, ub, step);
382 }
383 return bounds;
384 }
385
getElementAt(fir::FirOpBuilder & builder,mlir::Location loc,mlir::Value shape,mlir::Value slice,mlir::ValueRange inductionVariables)386 fir::ExtendedValue Fortran::lower::VectorSubscriptBox::getElementAt(
387 fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value shape,
388 mlir::Value slice, mlir::ValueRange inductionVariables) {
389 /// Generate the indexes for the array_coor inside the loops.
390 mlir::Type idxTy = builder.getIndexType();
391 llvm::SmallVector<mlir::Value> indexes;
392 size_t inductionIdx = inductionVariables.size() - 1;
393 for (const LoweredSubscript &subscript : loweredSubscripts)
394 Fortran::common::visit(
395 Fortran::common::visitors{
396 [&](const LoweredTriplet &triplet) {
397 indexes.emplace_back(inductionVariables[inductionIdx--]);
398 },
399 [&](const LoweredVectorSubscript &vector) {
400 mlir::Value vecIndex = inductionVariables[inductionIdx--];
401 mlir::Value vecBase = fir::getBase(vector.vector);
402 mlir::Type vecEleTy = fir::unwrapSequenceType(
403 fir::unwrapPassByRefType(vecBase.getType()));
404 mlir::Type refTy = builder.getRefType(vecEleTy);
405 auto vecEltRef = builder.create<fir::CoordinateOp>(
406 loc, refTy, vecBase, vecIndex);
407 auto vecElt =
408 builder.create<fir::LoadOp>(loc, vecEleTy, vecEltRef);
409 indexes.emplace_back(builder.createConvert(loc, idxTy, vecElt));
410 },
411 [&](const mlir::Value &i) {
412 indexes.emplace_back(builder.createConvert(loc, idxTy, i));
413 },
414 },
415 subscript);
416 mlir::Type refTy = builder.getRefType(getElementType());
417 auto elementAddr = builder.create<fir::ArrayCoorOp>(
418 loc, refTy, fir::getBase(loweredBase), shape, slice, indexes,
419 fir::getTypeParams(loweredBase));
420 fir::ExtendedValue element = fir::factory::arraySectionElementToExtendedValue(
421 builder, loc, loweredBase, elementAddr, slice);
422 if (!substringBounds.empty()) {
423 const fir::CharBoxValue *charBox = element.getCharBox();
424 assert(charBox && "substring requires CharBox base");
425 fir::factory::CharacterExprHelper helper{builder, loc};
426 return helper.createSubstring(*charBox, substringBounds);
427 }
428 return element;
429 }
430