xref: /llvm-project/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp (revision 1fd8d3fea53e6e4573cdce55bd38ef0a7813a442)
1 //===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
10 #include "mlir/Dialect/Vector/IR/VectorOps.h"
11 
12 namespace mlir::vector {
13 
14 FailureOr<ConstantOrScalableBound::BoundSize>
15 ConstantOrScalableBound::getSize() const {
16   if (map.isSingleConstant())
17     return BoundSize{map.getSingleConstantResult(), /*scalable=*/false};
18   if (map.getNumResults() != 1 || map.getNumInputs() != 1)
19     return failure();
20   auto binop = dyn_cast<AffineBinaryOpExpr>(map.getResult(0));
21   if (!binop || binop.getKind() != AffineExprKind::Mul)
22     return failure();
23   auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool {
24     if (auto cst = dyn_cast<AffineConstantExpr>(expr)) {
25       constant = cst.getValue();
26       return true;
27     }
28     return false;
29   };
30   // Match `s0 * cst` or `cst * s0`:
31   int64_t cst = 0;
32   auto lhs = binop.getLHS();
33   auto rhs = binop.getRHS();
34   if ((matchConstant(lhs, cst) && isa<AffineSymbolExpr>(rhs)) ||
35       (matchConstant(rhs, cst) && isa<AffineSymbolExpr>(lhs))) {
36     return BoundSize{cst, /*scalable=*/true};
37   }
38   return failure();
39 }
40 
41 char ScalableValueBoundsConstraintSet::ID = 0;
42 
43 FailureOr<ConstantOrScalableBound>
44 ScalableValueBoundsConstraintSet::computeScalableBound(
45     Value value, std::optional<int64_t> dim, unsigned vscaleMin,
46     unsigned vscaleMax, presburger::BoundType boundType, bool closedUB,
47     StopConditionFn stopCondition) {
48   using namespace presburger;
49   assert(vscaleMin <= vscaleMax);
50 
51   // No stop condition specified: Keep adding constraints until the worklist
52   // is empty.
53   auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
54                                   mlir::ValueBoundsConstraintSet &cstr) {
55     return false;
56   };
57 
58   ScalableValueBoundsConstraintSet scalableCstr(
59       value.getContext(), stopCondition ? stopCondition : defaultStopCondition,
60       vscaleMin, vscaleMax);
61   int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false);
62   scalableCstr.processWorklist();
63 
64   // Check the resulting constraints set is valid.
65   if (scalableCstr.cstr.isEmpty()) {
66     return failure();
67   }
68 
69   // Project out all columns apart from vscale and the starting point
70   // (value/dim). This should result in constraints in terms of vscale only.
71   auto projectOutFn = [&](ValueDim p) {
72     bool isStartingPoint =
73         p.first == value &&
74         p.second == dim.value_or(ValueBoundsConstraintSet::kIndexValue);
75     return p.first != scalableCstr.getVscaleValue() && !isStartingPoint;
76   };
77   scalableCstr.projectOut(projectOutFn);
78   scalableCstr.projectOutAnonymous(/*except=*/pos);
79   // Also project out local variables (these are not tracked by the
80   // ValueBoundsConstraintSet).
81   for (unsigned i = 0, e = scalableCstr.cstr.getNumLocalVars(); i < e; ++i) {
82     scalableCstr.cstr.projectOut(scalableCstr.cstr.getNumDimAndSymbolVars());
83   }
84 
85   assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
86              scalableCstr.positionToValueDim.size() &&
87          "inconsistent mapping state");
88 
89   // Check that the only columns left are vscale and the starting point.
90   for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) {
91     if (i == pos)
92       continue;
93     if (scalableCstr.positionToValueDim[i] !=
94         ValueDim(scalableCstr.getVscaleValue(),
95                  ValueBoundsConstraintSet::kIndexValue)) {
96       return failure();
97     }
98   }
99 
100   SmallVector<AffineMap, 1> lowerBound(1), upperBound(1);
101   scalableCstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound,
102                                    &upperBound, closedUB);
103 
104   auto invalidBound = [](auto &bound) {
105     return !bound[0] || bound[0].getNumResults() != 1;
106   };
107 
108   AffineMap bound = [&] {
109     if (boundType == BoundType::EQ && !invalidBound(lowerBound) &&
110         lowerBound[0] == upperBound[0]) {
111       return lowerBound[0];
112     } else if (boundType == BoundType::LB && !invalidBound(lowerBound)) {
113       return lowerBound[0];
114     } else if (boundType == BoundType::UB && !invalidBound(upperBound)) {
115       return upperBound[0];
116     }
117     return AffineMap{};
118   }();
119 
120   if (!bound)
121     return failure();
122 
123   return ConstantOrScalableBound{bound};
124 }
125 
126 } // namespace mlir::vector
127