//===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" namespace mlir::vector { FailureOr ConstantOrScalableBound::getSize() const { if (map.isSingleConstant()) return BoundSize{map.getSingleConstantResult(), /*scalable=*/false}; if (map.getNumResults() != 1 || map.getNumInputs() != 1) return failure(); auto binop = dyn_cast(map.getResult(0)); if (!binop || binop.getKind() != AffineExprKind::Mul) return failure(); auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool { if (auto cst = dyn_cast(expr)) { constant = cst.getValue(); return true; } return false; }; // Match `s0 * cst` or `cst * s0`: int64_t cst = 0; auto lhs = binop.getLHS(); auto rhs = binop.getRHS(); if ((matchConstant(lhs, cst) && isa(rhs)) || (matchConstant(rhs, cst) && isa(lhs))) { return BoundSize{cst, /*scalable=*/true}; } return failure(); } char ScalableValueBoundsConstraintSet::ID = 0; FailureOr ScalableValueBoundsConstraintSet::computeScalableBound( Value value, std::optional dim, unsigned vscaleMin, unsigned vscaleMax, presburger::BoundType boundType, bool closedUB, StopConditionFn stopCondition) { using namespace presburger; assert(vscaleMin <= vscaleMax); // No stop condition specified: Keep adding constraints until the worklist // is empty. auto defaultStopCondition = [&](Value v, std::optional dim, mlir::ValueBoundsConstraintSet &cstr) { return false; }; ScalableValueBoundsConstraintSet scalableCstr( value.getContext(), stopCondition ? stopCondition : defaultStopCondition, vscaleMin, vscaleMax); int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false); scalableCstr.processWorklist(); // Check the resulting constraints set is valid. if (scalableCstr.cstr.isEmpty()) { return failure(); } // Project out all columns apart from vscale and the starting point // (value/dim). This should result in constraints in terms of vscale only. auto projectOutFn = [&](ValueDim p) { bool isStartingPoint = p.first == value && p.second == dim.value_or(ValueBoundsConstraintSet::kIndexValue); return p.first != scalableCstr.getVscaleValue() && !isStartingPoint; }; scalableCstr.projectOut(projectOutFn); scalableCstr.projectOutAnonymous(/*except=*/pos); // Also project out local variables (these are not tracked by the // ValueBoundsConstraintSet). for (unsigned i = 0, e = scalableCstr.cstr.getNumLocalVars(); i < e; ++i) { scalableCstr.cstr.projectOut(scalableCstr.cstr.getNumDimAndSymbolVars()); } assert(scalableCstr.cstr.getNumDimAndSymbolVars() == scalableCstr.positionToValueDim.size() && "inconsistent mapping state"); // Check that the only columns left are vscale and the starting point. for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) { if (i == pos) continue; if (scalableCstr.positionToValueDim[i] != ValueDim(scalableCstr.getVscaleValue(), ValueBoundsConstraintSet::kIndexValue)) { return failure(); } } SmallVector lowerBound(1), upperBound(1); scalableCstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound, &upperBound, closedUB); auto invalidBound = [](auto &bound) { return !bound[0] || bound[0].getNumResults() != 1; }; AffineMap bound = [&] { if (boundType == BoundType::EQ && !invalidBound(lowerBound) && lowerBound[0] == upperBound[0]) { return lowerBound[0]; } else if (boundType == BoundType::LB && !invalidBound(lowerBound)) { return lowerBound[0]; } else if (boundType == BoundType::UB && !invalidBound(upperBound)) { return upperBound[0]; } return AffineMap{}; }(); if (!bound) return failure(); return ConstantOrScalableBound{bound}; } } // namespace mlir::vector