12861856bSBenjamin Maxwell //===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===// 22861856bSBenjamin Maxwell // 32861856bSBenjamin Maxwell // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 42861856bSBenjamin Maxwell // See https://llvm.org/LICENSE.txt for license information. 52861856bSBenjamin Maxwell // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 62861856bSBenjamin Maxwell // 72861856bSBenjamin Maxwell //===----------------------------------------------------------------------===// 82861856bSBenjamin Maxwell 92861856bSBenjamin Maxwell #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h" 102861856bSBenjamin Maxwell #include "mlir/Dialect/Vector/IR/VectorOps.h" 1168a19440SBenjamin Maxwell 122861856bSBenjamin Maxwell namespace mlir::vector { 132861856bSBenjamin Maxwell 142861856bSBenjamin Maxwell FailureOr<ConstantOrScalableBound::BoundSize> 152861856bSBenjamin Maxwell ConstantOrScalableBound::getSize() const { 162861856bSBenjamin Maxwell if (map.isSingleConstant()) 172861856bSBenjamin Maxwell return BoundSize{map.getSingleConstantResult(), /*scalable=*/false}; 182861856bSBenjamin Maxwell if (map.getNumResults() != 1 || map.getNumInputs() != 1) 192861856bSBenjamin Maxwell return failure(); 202861856bSBenjamin Maxwell auto binop = dyn_cast<AffineBinaryOpExpr>(map.getResult(0)); 212861856bSBenjamin Maxwell if (!binop || binop.getKind() != AffineExprKind::Mul) 222861856bSBenjamin Maxwell return failure(); 232861856bSBenjamin Maxwell auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool { 242861856bSBenjamin Maxwell if (auto cst = dyn_cast<AffineConstantExpr>(expr)) { 252861856bSBenjamin Maxwell constant = cst.getValue(); 262861856bSBenjamin Maxwell return true; 272861856bSBenjamin Maxwell } 282861856bSBenjamin Maxwell return false; 292861856bSBenjamin Maxwell }; 302861856bSBenjamin Maxwell // Match `s0 * cst` or `cst * s0`: 312861856bSBenjamin Maxwell int64_t cst = 0; 322861856bSBenjamin Maxwell auto lhs = binop.getLHS(); 332861856bSBenjamin Maxwell auto rhs = binop.getRHS(); 342861856bSBenjamin Maxwell if ((matchConstant(lhs, cst) && isa<AffineSymbolExpr>(rhs)) || 352861856bSBenjamin Maxwell (matchConstant(rhs, cst) && isa<AffineSymbolExpr>(lhs))) { 362861856bSBenjamin Maxwell return BoundSize{cst, /*scalable=*/true}; 372861856bSBenjamin Maxwell } 382861856bSBenjamin Maxwell return failure(); 392861856bSBenjamin Maxwell } 402861856bSBenjamin Maxwell 412861856bSBenjamin Maxwell char ScalableValueBoundsConstraintSet::ID = 0; 422861856bSBenjamin Maxwell 432861856bSBenjamin Maxwell FailureOr<ConstantOrScalableBound> 442861856bSBenjamin Maxwell ScalableValueBoundsConstraintSet::computeScalableBound( 452861856bSBenjamin Maxwell Value value, std::optional<int64_t> dim, unsigned vscaleMin, 462861856bSBenjamin Maxwell unsigned vscaleMax, presburger::BoundType boundType, bool closedUB, 472861856bSBenjamin Maxwell StopConditionFn stopCondition) { 482861856bSBenjamin Maxwell using namespace presburger; 492861856bSBenjamin Maxwell assert(vscaleMin <= vscaleMax); 502861856bSBenjamin Maxwell 515e4a4438SMatthias Springer // No stop condition specified: Keep adding constraints until the worklist 525e4a4438SMatthias Springer // is empty. 535e4a4438SMatthias Springer auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim, 545e4a4438SMatthias Springer mlir::ValueBoundsConstraintSet &cstr) { 555e4a4438SMatthias Springer return false; 565e4a4438SMatthias Springer }; 575e4a4438SMatthias Springer 585e4a4438SMatthias Springer ScalableValueBoundsConstraintSet scalableCstr( 595e4a4438SMatthias Springer value.getContext(), stopCondition ? stopCondition : defaultStopCondition, 605e4a4438SMatthias Springer vscaleMin, vscaleMax); 6176435f2dSMatthias Springer int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false); 6276435f2dSMatthias Springer scalableCstr.processWorklist(); 632861856bSBenjamin Maxwell 6429a925abSBenjamin Maxwell // Check the resulting constraints set is valid. 6529a925abSBenjamin Maxwell if (scalableCstr.cstr.isEmpty()) { 6629a925abSBenjamin Maxwell return failure(); 6729a925abSBenjamin Maxwell } 6829a925abSBenjamin Maxwell 6976435f2dSMatthias Springer // Project out all columns apart from vscale and the starting point 7076435f2dSMatthias Springer // (value/dim). This should result in constraints in terms of vscale only. 715e4a4438SMatthias Springer auto projectOutFn = [&](ValueDim p) { 7276435f2dSMatthias Springer bool isStartingPoint = 7376435f2dSMatthias Springer p.first == value && 7476435f2dSMatthias Springer p.second == dim.value_or(ValueBoundsConstraintSet::kIndexValue); 7576435f2dSMatthias Springer return p.first != scalableCstr.getVscaleValue() && !isStartingPoint; 765e4a4438SMatthias Springer }; 775e4a4438SMatthias Springer scalableCstr.projectOut(projectOutFn); 7868a19440SBenjamin Maxwell scalableCstr.projectOutAnonymous(/*except=*/pos); 7929a925abSBenjamin Maxwell // Also project out local variables (these are not tracked by the 8029a925abSBenjamin Maxwell // ValueBoundsConstraintSet). 8129a925abSBenjamin Maxwell for (unsigned i = 0, e = scalableCstr.cstr.getNumLocalVars(); i < e; ++i) { 8229a925abSBenjamin Maxwell scalableCstr.cstr.projectOut(scalableCstr.cstr.getNumDimAndSymbolVars()); 8329a925abSBenjamin Maxwell } 842861856bSBenjamin Maxwell 852861856bSBenjamin Maxwell assert(scalableCstr.cstr.getNumDimAndSymbolVars() == 862861856bSBenjamin Maxwell scalableCstr.positionToValueDim.size() && 872861856bSBenjamin Maxwell "inconsistent mapping state"); 882861856bSBenjamin Maxwell 8976435f2dSMatthias Springer // Check that the only columns left are vscale and the starting point. 902861856bSBenjamin Maxwell for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) { 912861856bSBenjamin Maxwell if (i == pos) 922861856bSBenjamin Maxwell continue; 932861856bSBenjamin Maxwell if (scalableCstr.positionToValueDim[i] != 942861856bSBenjamin Maxwell ValueDim(scalableCstr.getVscaleValue(), 952861856bSBenjamin Maxwell ValueBoundsConstraintSet::kIndexValue)) { 962861856bSBenjamin Maxwell return failure(); 972861856bSBenjamin Maxwell } 982861856bSBenjamin Maxwell } 992861856bSBenjamin Maxwell 1002861856bSBenjamin Maxwell SmallVector<AffineMap, 1> lowerBound(1), upperBound(1); 1012861856bSBenjamin Maxwell scalableCstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound, 1022861856bSBenjamin Maxwell &upperBound, closedUB); 1032861856bSBenjamin Maxwell 1042861856bSBenjamin Maxwell auto invalidBound = [](auto &bound) { 1052861856bSBenjamin Maxwell return !bound[0] || bound[0].getNumResults() != 1; 1062861856bSBenjamin Maxwell }; 1072861856bSBenjamin Maxwell 1082861856bSBenjamin Maxwell AffineMap bound = [&] { 1092861856bSBenjamin Maxwell if (boundType == BoundType::EQ && !invalidBound(lowerBound) && 110*1fd8d3feSChuvak lowerBound[0] == upperBound[0]) { 1112861856bSBenjamin Maxwell return lowerBound[0]; 1122861856bSBenjamin Maxwell } else if (boundType == BoundType::LB && !invalidBound(lowerBound)) { 1132861856bSBenjamin Maxwell return lowerBound[0]; 1142861856bSBenjamin Maxwell } else if (boundType == BoundType::UB && !invalidBound(upperBound)) { 1152861856bSBenjamin Maxwell return upperBound[0]; 1162861856bSBenjamin Maxwell } 1172861856bSBenjamin Maxwell return AffineMap{}; 1182861856bSBenjamin Maxwell }(); 1192861856bSBenjamin Maxwell 1202861856bSBenjamin Maxwell if (!bound) 1212861856bSBenjamin Maxwell return failure(); 1222861856bSBenjamin Maxwell 1232861856bSBenjamin Maxwell return ConstantOrScalableBound{bound}; 1242861856bSBenjamin Maxwell } 1252861856bSBenjamin Maxwell 1262861856bSBenjamin Maxwell } // namespace mlir::vector 127