xref: /llvm-project/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp (revision 1fd8d3fea53e6e4573cdce55bd38ef0a7813a442)
12861856bSBenjamin Maxwell //===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===//
22861856bSBenjamin Maxwell //
32861856bSBenjamin Maxwell // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42861856bSBenjamin Maxwell // See https://llvm.org/LICENSE.txt for license information.
52861856bSBenjamin Maxwell // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62861856bSBenjamin Maxwell //
72861856bSBenjamin Maxwell //===----------------------------------------------------------------------===//
82861856bSBenjamin Maxwell 
92861856bSBenjamin Maxwell #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
102861856bSBenjamin Maxwell #include "mlir/Dialect/Vector/IR/VectorOps.h"
1168a19440SBenjamin Maxwell 
122861856bSBenjamin Maxwell namespace mlir::vector {
132861856bSBenjamin Maxwell 
142861856bSBenjamin Maxwell FailureOr<ConstantOrScalableBound::BoundSize>
152861856bSBenjamin Maxwell ConstantOrScalableBound::getSize() const {
162861856bSBenjamin Maxwell   if (map.isSingleConstant())
172861856bSBenjamin Maxwell     return BoundSize{map.getSingleConstantResult(), /*scalable=*/false};
182861856bSBenjamin Maxwell   if (map.getNumResults() != 1 || map.getNumInputs() != 1)
192861856bSBenjamin Maxwell     return failure();
202861856bSBenjamin Maxwell   auto binop = dyn_cast<AffineBinaryOpExpr>(map.getResult(0));
212861856bSBenjamin Maxwell   if (!binop || binop.getKind() != AffineExprKind::Mul)
222861856bSBenjamin Maxwell     return failure();
232861856bSBenjamin Maxwell   auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool {
242861856bSBenjamin Maxwell     if (auto cst = dyn_cast<AffineConstantExpr>(expr)) {
252861856bSBenjamin Maxwell       constant = cst.getValue();
262861856bSBenjamin Maxwell       return true;
272861856bSBenjamin Maxwell     }
282861856bSBenjamin Maxwell     return false;
292861856bSBenjamin Maxwell   };
302861856bSBenjamin Maxwell   // Match `s0 * cst` or `cst * s0`:
312861856bSBenjamin Maxwell   int64_t cst = 0;
322861856bSBenjamin Maxwell   auto lhs = binop.getLHS();
332861856bSBenjamin Maxwell   auto rhs = binop.getRHS();
342861856bSBenjamin Maxwell   if ((matchConstant(lhs, cst) && isa<AffineSymbolExpr>(rhs)) ||
352861856bSBenjamin Maxwell       (matchConstant(rhs, cst) && isa<AffineSymbolExpr>(lhs))) {
362861856bSBenjamin Maxwell     return BoundSize{cst, /*scalable=*/true};
372861856bSBenjamin Maxwell   }
382861856bSBenjamin Maxwell   return failure();
392861856bSBenjamin Maxwell }
402861856bSBenjamin Maxwell 
412861856bSBenjamin Maxwell char ScalableValueBoundsConstraintSet::ID = 0;
422861856bSBenjamin Maxwell 
432861856bSBenjamin Maxwell FailureOr<ConstantOrScalableBound>
442861856bSBenjamin Maxwell ScalableValueBoundsConstraintSet::computeScalableBound(
452861856bSBenjamin Maxwell     Value value, std::optional<int64_t> dim, unsigned vscaleMin,
462861856bSBenjamin Maxwell     unsigned vscaleMax, presburger::BoundType boundType, bool closedUB,
472861856bSBenjamin Maxwell     StopConditionFn stopCondition) {
482861856bSBenjamin Maxwell   using namespace presburger;
492861856bSBenjamin Maxwell   assert(vscaleMin <= vscaleMax);
502861856bSBenjamin Maxwell 
515e4a4438SMatthias Springer   // No stop condition specified: Keep adding constraints until the worklist
525e4a4438SMatthias Springer   // is empty.
535e4a4438SMatthias Springer   auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
545e4a4438SMatthias Springer                                   mlir::ValueBoundsConstraintSet &cstr) {
555e4a4438SMatthias Springer     return false;
565e4a4438SMatthias Springer   };
575e4a4438SMatthias Springer 
585e4a4438SMatthias Springer   ScalableValueBoundsConstraintSet scalableCstr(
595e4a4438SMatthias Springer       value.getContext(), stopCondition ? stopCondition : defaultStopCondition,
605e4a4438SMatthias Springer       vscaleMin, vscaleMax);
6176435f2dSMatthias Springer   int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false);
6276435f2dSMatthias Springer   scalableCstr.processWorklist();
632861856bSBenjamin Maxwell 
6429a925abSBenjamin Maxwell   // Check the resulting constraints set is valid.
6529a925abSBenjamin Maxwell   if (scalableCstr.cstr.isEmpty()) {
6629a925abSBenjamin Maxwell     return failure();
6729a925abSBenjamin Maxwell   }
6829a925abSBenjamin Maxwell 
6976435f2dSMatthias Springer   // Project out all columns apart from vscale and the starting point
7076435f2dSMatthias Springer   // (value/dim). This should result in constraints in terms of vscale only.
715e4a4438SMatthias Springer   auto projectOutFn = [&](ValueDim p) {
7276435f2dSMatthias Springer     bool isStartingPoint =
7376435f2dSMatthias Springer         p.first == value &&
7476435f2dSMatthias Springer         p.second == dim.value_or(ValueBoundsConstraintSet::kIndexValue);
7576435f2dSMatthias Springer     return p.first != scalableCstr.getVscaleValue() && !isStartingPoint;
765e4a4438SMatthias Springer   };
775e4a4438SMatthias Springer   scalableCstr.projectOut(projectOutFn);
7868a19440SBenjamin Maxwell   scalableCstr.projectOutAnonymous(/*except=*/pos);
7929a925abSBenjamin Maxwell   // Also project out local variables (these are not tracked by the
8029a925abSBenjamin Maxwell   // ValueBoundsConstraintSet).
8129a925abSBenjamin Maxwell   for (unsigned i = 0, e = scalableCstr.cstr.getNumLocalVars(); i < e; ++i) {
8229a925abSBenjamin Maxwell     scalableCstr.cstr.projectOut(scalableCstr.cstr.getNumDimAndSymbolVars());
8329a925abSBenjamin Maxwell   }
842861856bSBenjamin Maxwell 
852861856bSBenjamin Maxwell   assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
862861856bSBenjamin Maxwell              scalableCstr.positionToValueDim.size() &&
872861856bSBenjamin Maxwell          "inconsistent mapping state");
882861856bSBenjamin Maxwell 
8976435f2dSMatthias Springer   // Check that the only columns left are vscale and the starting point.
902861856bSBenjamin Maxwell   for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) {
912861856bSBenjamin Maxwell     if (i == pos)
922861856bSBenjamin Maxwell       continue;
932861856bSBenjamin Maxwell     if (scalableCstr.positionToValueDim[i] !=
942861856bSBenjamin Maxwell         ValueDim(scalableCstr.getVscaleValue(),
952861856bSBenjamin Maxwell                  ValueBoundsConstraintSet::kIndexValue)) {
962861856bSBenjamin Maxwell       return failure();
972861856bSBenjamin Maxwell     }
982861856bSBenjamin Maxwell   }
992861856bSBenjamin Maxwell 
1002861856bSBenjamin Maxwell   SmallVector<AffineMap, 1> lowerBound(1), upperBound(1);
1012861856bSBenjamin Maxwell   scalableCstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound,
1022861856bSBenjamin Maxwell                                    &upperBound, closedUB);
1032861856bSBenjamin Maxwell 
1042861856bSBenjamin Maxwell   auto invalidBound = [](auto &bound) {
1052861856bSBenjamin Maxwell     return !bound[0] || bound[0].getNumResults() != 1;
1062861856bSBenjamin Maxwell   };
1072861856bSBenjamin Maxwell 
1082861856bSBenjamin Maxwell   AffineMap bound = [&] {
1092861856bSBenjamin Maxwell     if (boundType == BoundType::EQ && !invalidBound(lowerBound) &&
110*1fd8d3feSChuvak         lowerBound[0] == upperBound[0]) {
1112861856bSBenjamin Maxwell       return lowerBound[0];
1122861856bSBenjamin Maxwell     } else if (boundType == BoundType::LB && !invalidBound(lowerBound)) {
1132861856bSBenjamin Maxwell       return lowerBound[0];
1142861856bSBenjamin Maxwell     } else if (boundType == BoundType::UB && !invalidBound(upperBound)) {
1152861856bSBenjamin Maxwell       return upperBound[0];
1162861856bSBenjamin Maxwell     }
1172861856bSBenjamin Maxwell     return AffineMap{};
1182861856bSBenjamin Maxwell   }();
1192861856bSBenjamin Maxwell 
1202861856bSBenjamin Maxwell   if (!bound)
1212861856bSBenjamin Maxwell     return failure();
1222861856bSBenjamin Maxwell 
1232861856bSBenjamin Maxwell   return ConstantOrScalableBound{bound};
1242861856bSBenjamin Maxwell }
1252861856bSBenjamin Maxwell 
1262861856bSBenjamin Maxwell } // namespace mlir::vector
127