1 //===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h" 10 #include "mlir/Dialect/Vector/IR/VectorOps.h" 11 12 namespace mlir::vector { 13 14 FailureOr<ConstantOrScalableBound::BoundSize> 15 ConstantOrScalableBound::getSize() const { 16 if (map.isSingleConstant()) 17 return BoundSize{map.getSingleConstantResult(), /*scalable=*/false}; 18 if (map.getNumResults() != 1 || map.getNumInputs() != 1) 19 return failure(); 20 auto binop = dyn_cast<AffineBinaryOpExpr>(map.getResult(0)); 21 if (!binop || binop.getKind() != AffineExprKind::Mul) 22 return failure(); 23 auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool { 24 if (auto cst = dyn_cast<AffineConstantExpr>(expr)) { 25 constant = cst.getValue(); 26 return true; 27 } 28 return false; 29 }; 30 // Match `s0 * cst` or `cst * s0`: 31 int64_t cst = 0; 32 auto lhs = binop.getLHS(); 33 auto rhs = binop.getRHS(); 34 if ((matchConstant(lhs, cst) && isa<AffineSymbolExpr>(rhs)) || 35 (matchConstant(rhs, cst) && isa<AffineSymbolExpr>(lhs))) { 36 return BoundSize{cst, /*scalable=*/true}; 37 } 38 return failure(); 39 } 40 41 char ScalableValueBoundsConstraintSet::ID = 0; 42 43 FailureOr<ConstantOrScalableBound> 44 ScalableValueBoundsConstraintSet::computeScalableBound( 45 Value value, std::optional<int64_t> dim, unsigned vscaleMin, 46 unsigned vscaleMax, presburger::BoundType boundType, bool closedUB, 47 StopConditionFn stopCondition) { 48 using namespace presburger; 49 assert(vscaleMin <= vscaleMax); 50 51 // No stop condition specified: Keep adding constraints until the worklist 52 // is empty. 53 auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim, 54 mlir::ValueBoundsConstraintSet &cstr) { 55 return false; 56 }; 57 58 ScalableValueBoundsConstraintSet scalableCstr( 59 value.getContext(), stopCondition ? stopCondition : defaultStopCondition, 60 vscaleMin, vscaleMax); 61 int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false); 62 scalableCstr.processWorklist(); 63 64 // Check the resulting constraints set is valid. 65 if (scalableCstr.cstr.isEmpty()) { 66 return failure(); 67 } 68 69 // Project out all columns apart from vscale and the starting point 70 // (value/dim). This should result in constraints in terms of vscale only. 71 auto projectOutFn = [&](ValueDim p) { 72 bool isStartingPoint = 73 p.first == value && 74 p.second == dim.value_or(ValueBoundsConstraintSet::kIndexValue); 75 return p.first != scalableCstr.getVscaleValue() && !isStartingPoint; 76 }; 77 scalableCstr.projectOut(projectOutFn); 78 scalableCstr.projectOutAnonymous(/*except=*/pos); 79 // Also project out local variables (these are not tracked by the 80 // ValueBoundsConstraintSet). 81 for (unsigned i = 0, e = scalableCstr.cstr.getNumLocalVars(); i < e; ++i) { 82 scalableCstr.cstr.projectOut(scalableCstr.cstr.getNumDimAndSymbolVars()); 83 } 84 85 assert(scalableCstr.cstr.getNumDimAndSymbolVars() == 86 scalableCstr.positionToValueDim.size() && 87 "inconsistent mapping state"); 88 89 // Check that the only columns left are vscale and the starting point. 90 for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) { 91 if (i == pos) 92 continue; 93 if (scalableCstr.positionToValueDim[i] != 94 ValueDim(scalableCstr.getVscaleValue(), 95 ValueBoundsConstraintSet::kIndexValue)) { 96 return failure(); 97 } 98 } 99 100 SmallVector<AffineMap, 1> lowerBound(1), upperBound(1); 101 scalableCstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound, 102 &upperBound, closedUB); 103 104 auto invalidBound = [](auto &bound) { 105 return !bound[0] || bound[0].getNumResults() != 1; 106 }; 107 108 AffineMap bound = [&] { 109 if (boundType == BoundType::EQ && !invalidBound(lowerBound) && 110 lowerBound[0] == upperBound[0]) { 111 return lowerBound[0]; 112 } else if (boundType == BoundType::LB && !invalidBound(lowerBound)) { 113 return lowerBound[0]; 114 } else if (boundType == BoundType::UB && !invalidBound(upperBound)) { 115 return upperBound[0]; 116 } 117 return AffineMap{}; 118 }(); 119 120 if (!bound) 121 return failure(); 122 123 return ConstantOrScalableBound{bound}; 124 } 125 126 } // namespace mlir::vector 127