//===-- VectorSubscripts.cpp -- Vector subscripts tools -------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ // //===----------------------------------------------------------------------===// #include "flang/Lower/VectorSubscripts.h" #include "flang/Lower/AbstractConverter.h" #include "flang/Lower/Support/Utils.h" #include "flang/Optimizer/Builder/Character.h" #include "flang/Optimizer/Builder/Complex.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/Todo.h" #include "flang/Semantics/expression.h" namespace { /// Helper class to lower a designator containing vector subscripts into a /// lowered representation that can be worked with. class VectorSubscriptBoxBuilder { public: VectorSubscriptBoxBuilder(mlir::Location loc, Fortran::lower::AbstractConverter &converter, Fortran::lower::StatementContext &stmtCtx) : converter{converter}, stmtCtx{stmtCtx}, loc{loc} {} Fortran::lower::VectorSubscriptBox gen(const Fortran::lower::SomeExpr &expr) { elementType = genDesignator(expr); return Fortran::lower::VectorSubscriptBox( std::move(loweredBase), std::move(loweredSubscripts), std::move(componentPath), substringBounds, elementType); } private: using LoweredVectorSubscript = Fortran::lower::VectorSubscriptBox::LoweredVectorSubscript; using LoweredTriplet = Fortran::lower::VectorSubscriptBox::LoweredTriplet; using LoweredSubscript = Fortran::lower::VectorSubscriptBox::LoweredSubscript; using MaybeSubstring = Fortran::lower::VectorSubscriptBox::MaybeSubstring; /// genDesignator unwraps a Designator and calls `gen` on what the /// designator actually contains. template mlir::Type genDesignator(const A &) { fir::emitFatalError(loc, "expr must contain a designator"); } template mlir::Type genDesignator(const Fortran::evaluate::Expr &expr) { using ExprVariant = decltype(Fortran::evaluate::Expr::u); using Designator = Fortran::evaluate::Designator; if constexpr (Fortran::common::HasMember) { const auto &designator = std::get(expr.u); return Fortran::common::visit([&](const auto &x) { return gen(x); }, designator.u); } else { return Fortran::common::visit( [&](const auto &x) { return genDesignator(x); }, expr.u); } } // The gen(X) methods visit X to lower its base and subscripts and return the // type of X elements. mlir::Type gen(const Fortran::evaluate::DataRef &dataRef) { return Fortran::common::visit( [&](const auto &ref) -> mlir::Type { return gen(ref); }, dataRef.u); } mlir::Type gen(const Fortran::evaluate::SymbolRef &symRef) { // Never visited because expr lowering is used to lowered the ranked // ArrayRef. fir::emitFatalError( loc, "expected at least one ArrayRef with vector susbcripts"); } mlir::Type gen(const Fortran::evaluate::Substring &substring) { // StaticDataObject::Pointer bases are constants and cannot be // subscripted, so the base must be a DataRef here. mlir::Type baseElementType = gen(std::get(substring.parent())); fir::FirOpBuilder &builder = converter.getFirOpBuilder(); mlir::Type idxTy = builder.getIndexType(); mlir::Value lb = genScalarValue(substring.lower()); substringBounds.emplace_back(builder.createConvert(loc, idxTy, lb)); if (const auto &ubExpr = substring.upper()) { mlir::Value ub = genScalarValue(*ubExpr); substringBounds.emplace_back(builder.createConvert(loc, idxTy, ub)); } return baseElementType; } mlir::Type gen(const Fortran::evaluate::ComplexPart &complexPart) { auto complexType = gen(complexPart.complex()); fir::FirOpBuilder &builder = converter.getFirOpBuilder(); mlir::Type i32Ty = builder.getI32Type(); // llvm's GEP requires i32 mlir::Value offset = builder.createIntegerConstant( loc, i32Ty, complexPart.part() == Fortran::evaluate::ComplexPart::Part::RE ? 0 : 1); componentPath.emplace_back(offset); return fir::factory::Complex{builder, loc}.getComplexPartType(complexType); } mlir::Type gen(const Fortran::evaluate::Component &component) { auto recTy = mlir::cast(gen(component.base())); const Fortran::semantics::Symbol &componentSymbol = component.GetLastSymbol(); // Parent components will not be found here, they are not part // of the FIR type and cannot be used in the path yet. if (componentSymbol.test(Fortran::semantics::Symbol::Flag::ParentComp)) TODO(loc, "reference to parent component"); mlir::Type fldTy = fir::FieldType::get(&converter.getMLIRContext()); llvm::StringRef componentName = toStringRef(componentSymbol.name()); // Parameters threading in field_index is not yet very clear. We only // have the ones of the ranked array ref at hand, but it looks like // the fir.field_index expects the one of the direct base. if (recTy.getNumLenParams() != 0) TODO(loc, "threading length parameters in field index op"); fir::FirOpBuilder &builder = converter.getFirOpBuilder(); componentPath.emplace_back(builder.create( loc, fldTy, componentName, recTy, /*typeParams*/ std::nullopt)); return fir::unwrapSequenceType(recTy.getType(componentName)); } mlir::Type gen(const Fortran::evaluate::ArrayRef &arrayRef) { auto isTripletOrVector = [](const Fortran::evaluate::Subscript &subscript) -> bool { return Fortran::common::visit( Fortran::common::visitors{ [](const Fortran::evaluate::IndirectSubscriptIntegerExpr &expr) { return expr.value().Rank() != 0; }, [&](const Fortran::evaluate::Triplet &) { return true; }}, subscript.u); }; if (llvm::any_of(arrayRef.subscript(), isTripletOrVector)) return genRankedArrayRefSubscriptAndBase(arrayRef); // This is a scalar ArrayRef (only scalar indexes), collect the indexes and // visit the base that must contain another arrayRef with the vector // subscript. mlir::Type elementType = gen(namedEntityToDataRef(arrayRef.base())); for (const Fortran::evaluate::Subscript &subscript : arrayRef.subscript()) { const auto &expr = std::get( subscript.u); componentPath.emplace_back(genScalarValue(expr.value())); } return elementType; } /// Lower the subscripts and base of the ArrayRef that is an array (there must /// be one since there is a vector subscript, and there can only be one /// according to C925). mlir::Type genRankedArrayRefSubscriptAndBase( const Fortran::evaluate::ArrayRef &arrayRef) { // Lower the save the base Fortran::lower::SomeExpr baseExpr = namedEntityToExpr(arrayRef.base()); loweredBase = converter.genExprAddr(baseExpr, stmtCtx); // Lower and save the subscripts fir::FirOpBuilder &builder = converter.getFirOpBuilder(); mlir::Type idxTy = builder.getIndexType(); mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); for (const auto &subscript : llvm::enumerate(arrayRef.subscript())) { Fortran::common::visit( Fortran::common::visitors{ [&](const Fortran::evaluate::IndirectSubscriptIntegerExpr &expr) { if (expr.value().Rank() == 0) { // Simple scalar subscript loweredSubscripts.emplace_back(genScalarValue(expr.value())); } else { // Vector subscript. // Remove conversion if any to avoid temp creation that may // have been added by the front-end to avoid the creation of a // temp array value. auto vector = converter.genExprAddr( ignoreEvConvert(expr.value()), stmtCtx); mlir::Value size = fir::factory::readExtent(builder, loc, vector, /*dim=*/0); size = builder.createConvert(loc, idxTy, size); loweredSubscripts.emplace_back( LoweredVectorSubscript{std::move(vector), size}); } }, [&](const Fortran::evaluate::Triplet &triplet) { mlir::Value lb, ub; if (const auto &lbExpr = triplet.lower()) lb = genScalarValue(*lbExpr); else lb = fir::factory::readLowerBound(builder, loc, loweredBase, subscript.index(), one); if (const auto &ubExpr = triplet.upper()) ub = genScalarValue(*ubExpr); else ub = fir::factory::readExtent(builder, loc, loweredBase, subscript.index()); lb = builder.createConvert(loc, idxTy, lb); ub = builder.createConvert(loc, idxTy, ub); mlir::Value stride = genScalarValue(triplet.stride()); stride = builder.createConvert(loc, idxTy, stride); loweredSubscripts.emplace_back(LoweredTriplet{lb, ub, stride}); }, }, subscript.value().u); } return fir::unwrapSequenceType( fir::unwrapPassByRefType(fir::getBase(loweredBase).getType())); } mlir::Type gen(const Fortran::evaluate::CoarrayRef &) { // Is this possible/legal ? TODO(loc, "coarray: reference to coarray object with vector subscript in " "IO input"); } template mlir::Value genScalarValue(const A &expr) { return fir::getBase(converter.genExprValue(toEvExpr(expr), stmtCtx)); } Fortran::evaluate::DataRef namedEntityToDataRef(const Fortran::evaluate::NamedEntity &namedEntity) { if (namedEntity.IsSymbol()) return Fortran::evaluate::DataRef{namedEntity.GetFirstSymbol()}; return Fortran::evaluate::DataRef{namedEntity.GetComponent()}; } Fortran::lower::SomeExpr namedEntityToExpr(const Fortran::evaluate::NamedEntity &namedEntity) { return Fortran::evaluate::AsGenericExpr(namedEntityToDataRef(namedEntity)) .value(); } Fortran::lower::AbstractConverter &converter; Fortran::lower::StatementContext &stmtCtx; mlir::Location loc; /// Elements of VectorSubscriptBox being built. fir::ExtendedValue loweredBase; llvm::SmallVector loweredSubscripts; llvm::SmallVector componentPath; MaybeSubstring substringBounds; mlir::Type elementType; }; } // namespace Fortran::lower::VectorSubscriptBox Fortran::lower::genVectorSubscriptBox( mlir::Location loc, Fortran::lower::AbstractConverter &converter, Fortran::lower::StatementContext &stmtCtx, const Fortran::lower::SomeExpr &expr) { return VectorSubscriptBoxBuilder(loc, converter, stmtCtx).gen(expr); } template mlir::Value Fortran::lower::VectorSubscriptBox::loopOverElementsBase( fir::FirOpBuilder &builder, mlir::Location loc, const Generator &elementalGenerator, [[maybe_unused]] mlir::Value initialCondition) { mlir::Value shape = builder.createShape(loc, loweredBase); mlir::Value slice = createSlice(builder, loc); // Create loop nest for triplets and vector subscripts in column // major order. llvm::SmallVector inductionVariables; LoopType outerLoop; for (auto [lb, ub, step] : genLoopBounds(builder, loc)) { LoopType loop; if constexpr (std::is_same_v) { loop = builder.create(loc, lb, ub, step, initialCondition); initialCondition = loop.getIterateVar(); if (!outerLoop) outerLoop = loop; else builder.create(loc, loop.getResult(0)); } else { loop = builder.create(loc, lb, ub, step, /*unordered=*/false); if (!outerLoop) outerLoop = loop; } builder.setInsertionPointToStart(loop.getBody()); inductionVariables.push_back(loop.getInductionVar()); } assert(outerLoop && !inductionVariables.empty() && "at least one loop should be created"); fir::ExtendedValue elem = getElementAt(builder, loc, shape, slice, inductionVariables); if constexpr (std::is_same_v) { auto res = elementalGenerator(elem); builder.create(loc, res); builder.setInsertionPointAfter(outerLoop); return outerLoop.getResult(0); } else { elementalGenerator(elem); builder.setInsertionPointAfter(outerLoop); return {}; } } void Fortran::lower::VectorSubscriptBox::loopOverElements( fir::FirOpBuilder &builder, mlir::Location loc, const ElementalGenerator &elementalGenerator) { mlir::Value initialCondition; loopOverElementsBase( builder, loc, elementalGenerator, initialCondition); } mlir::Value Fortran::lower::VectorSubscriptBox::loopOverElementsWhile( fir::FirOpBuilder &builder, mlir::Location loc, const ElementalGeneratorWithBoolReturn &elementalGenerator, mlir::Value initialCondition) { return loopOverElementsBase( builder, loc, elementalGenerator, initialCondition); } mlir::Value Fortran::lower::VectorSubscriptBox::createSlice(fir::FirOpBuilder &builder, mlir::Location loc) { mlir::Type idxTy = builder.getIndexType(); llvm::SmallVector triples; mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); auto undef = builder.create(loc, idxTy); for (const LoweredSubscript &subscript : loweredSubscripts) Fortran::common::visit(Fortran::common::visitors{ [&](const LoweredTriplet &triplet) { triples.emplace_back(triplet.lb); triples.emplace_back(triplet.ub); triples.emplace_back(triplet.stride); }, [&](const LoweredVectorSubscript &vector) { triples.emplace_back(one); triples.emplace_back(vector.size); triples.emplace_back(one); }, [&](const mlir::Value &i) { triples.emplace_back(i); triples.emplace_back(undef); triples.emplace_back(undef); }, }, subscript); return builder.create(loc, triples, componentPath); } llvm::SmallVector> Fortran::lower::VectorSubscriptBox::genLoopBounds(fir::FirOpBuilder &builder, mlir::Location loc) { mlir::Type idxTy = builder.getIndexType(); mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0); llvm::SmallVector> bounds; size_t dimension = loweredSubscripts.size(); for (const LoweredSubscript &subscript : llvm::reverse(loweredSubscripts)) { --dimension; if (std::holds_alternative(subscript)) continue; mlir::Value lb, ub, step; if (const auto *triplet = std::get_if(&subscript)) { mlir::Value extent = builder.genExtentFromTriplet( loc, triplet->lb, triplet->ub, triplet->stride, idxTy); mlir::Value baseLb = fir::factory::readLowerBound( builder, loc, loweredBase, dimension, one); baseLb = builder.createConvert(loc, idxTy, baseLb); lb = baseLb; ub = builder.create(loc, idxTy, extent, one); ub = builder.create(loc, idxTy, ub, baseLb); step = one; } else { const auto &vector = std::get(subscript); lb = zero; ub = builder.create(loc, idxTy, vector.size, one); step = one; } bounds.emplace_back(lb, ub, step); } return bounds; } fir::ExtendedValue Fortran::lower::VectorSubscriptBox::getElementAt( fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value shape, mlir::Value slice, mlir::ValueRange inductionVariables) { /// Generate the indexes for the array_coor inside the loops. mlir::Type idxTy = builder.getIndexType(); llvm::SmallVector indexes; size_t inductionIdx = inductionVariables.size() - 1; for (const LoweredSubscript &subscript : loweredSubscripts) Fortran::common::visit( Fortran::common::visitors{ [&](const LoweredTriplet &triplet) { indexes.emplace_back(inductionVariables[inductionIdx--]); }, [&](const LoweredVectorSubscript &vector) { mlir::Value vecIndex = inductionVariables[inductionIdx--]; mlir::Value vecBase = fir::getBase(vector.vector); mlir::Type vecEleTy = fir::unwrapSequenceType( fir::unwrapPassByRefType(vecBase.getType())); mlir::Type refTy = builder.getRefType(vecEleTy); auto vecEltRef = builder.create( loc, refTy, vecBase, vecIndex); auto vecElt = builder.create(loc, vecEleTy, vecEltRef); indexes.emplace_back(builder.createConvert(loc, idxTy, vecElt)); }, [&](const mlir::Value &i) { indexes.emplace_back(builder.createConvert(loc, idxTy, i)); }, }, subscript); mlir::Type refTy = builder.getRefType(getElementType()); auto elementAddr = builder.create( loc, refTy, fir::getBase(loweredBase), shape, slice, indexes, fir::getTypeParams(loweredBase)); fir::ExtendedValue element = fir::factory::arraySectionElementToExtendedValue( builder, loc, loweredBase, elementAddr, slice); if (!substringBounds.empty()) { const fir::CharBoxValue *charBox = element.getCharBox(); assert(charBox && "substring requires CharBox base"); fir::factory::CharacterExprHelper helper{builder, loc}; return helper.createSubstring(*charBox, substringBounds); } return element; }