xref: /llvm-project/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h (revision 29a925abb660104b413b15075b3a19793825f57e)
1 //===- ScalableValueBoundsConstraintSet.h - Scalable Value Bounds ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H
10 #define MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H
11 
#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

#include <cassert>
#include <utility>
15 
16 namespace mlir::vector {
17 
namespace detail {

/// Parent class for the value bounds `RTTIExtends` used by
/// `ScalableValueBoundsConstraintSet`. It inherits from
/// `::mlir::ValueBoundsConstraintSet` with *protected* access, so none of the
/// base class methods are publicly visible on derived classes by default;
/// a derived class must re-export the ones it wants with a `using`
/// declaration. This is deliberate: some of the base methods do not use the
/// `ScalableValueBoundsConstraintSet` and so may produce unexpected results if
/// called directly.
struct ValueBoundsConstraintSet : protected ::mlir::ValueBoundsConstraintSet {
  // Inherit every base constructor so that derived classes can pass their
  // constructor arguments straight through (via `RTTIExtends`).
  using ::mlir::ValueBoundsConstraintSet::ValueBoundsConstraintSet;
};
} // namespace detail
27 
28 /// A version of `ValueBoundsConstraintSet` that can solve for scalable bounds.
29 struct ScalableValueBoundsConstraintSet
30     : public llvm::RTTIExtends<ScalableValueBoundsConstraintSet,
31                                detail::ValueBoundsConstraintSet> {
ScalableValueBoundsConstraintSetScalableValueBoundsConstraintSet32   ScalableValueBoundsConstraintSet(
33       MLIRContext *context,
34       ValueBoundsConstraintSet::StopConditionFn stopCondition,
35       unsigned vscaleMin, unsigned vscaleMax)
36       : RTTIExtends(context, stopCondition,
37                     /*addConservativeSemiAffineBounds=*/true),
38         vscaleMin(vscaleMin), vscaleMax(vscaleMax) {};
39 
40   using RTTIExtends::bound;
41   using RTTIExtends::StopConditionFn;
42 
43   /// A thin wrapper over an `AffineMap` which can represent a constant bound,
44   /// or a scalable bound (in terms of vscale). The `AffineMap` will always
45   /// take at most one parameter, vscale, and returns a single result, which is
46   /// the bound of value.
47   struct ConstantOrScalableBound {
48     AffineMap map;
49 
50     struct BoundSize {
51       int64_t baseSize{0};
52       bool scalable{false};
53     };
54 
55     /// Get the (possibly) scalable size of the bound, returns failure if
56     /// the bound cannot be represented as a single quantity.
57     FailureOr<BoundSize> getSize() const;
58   };
59 
60   /// Computes a (possibly) scalable bound for a given value. This is
61   /// similar to `ValueBoundsConstraintSet::computeConstantBound()`, but
62   /// uses knowledge of the range of vscale to compute either a constant
63   /// bound, an expression in terms of vscale, or failure if no bound can
64   /// be computed.
65   ///
66   /// The resulting `AffineMap` will always take at most one parameter,
67   /// vscale, and return a single result, which is the bound of `value`.
68   ///
69   /// Note: `vscaleMin` must be `<=` to `vscaleMax`. If `vscaleMin` ==
70   /// `vscaleMax`, the resulting bound (if found), will be constant.
71   static FailureOr<ConstantOrScalableBound>
72   computeScalableBound(Value value, std::optional<int64_t> dim,
73                        unsigned vscaleMin, unsigned vscaleMax,
74                        presburger::BoundType boundType, bool closedUB = true,
75                        StopConditionFn stopCondition = nullptr);
76 
77   /// Get the value of vscale. Returns `nullptr` vscale as not been encountered.
getVscaleValueScalableValueBoundsConstraintSet78   Value getVscaleValue() const { return vscale; }
79 
80   /// Sets the value of vscale. Asserts if vscale has already been set.
setVscaleScalableValueBoundsConstraintSet81   void setVscale(vector::VectorScaleOp vscaleOp) {
82     assert(!vscale && "expected vscale to be unset");
83     vscale = vscaleOp.getResult();
84   }
85 
86   /// The minimum possible value of vscale.
getVscaleMinScalableValueBoundsConstraintSet87   unsigned getVscaleMin() const { return vscaleMin; }
88 
89   /// The maximum possible value of vscale.
getVscaleMaxScalableValueBoundsConstraintSet90   unsigned getVscaleMax() const { return vscaleMax; }
91 
92   static char ID;
93 
94 private:
95   const unsigned vscaleMin;
96   const unsigned vscaleMax;
97 
98   // This will be set when the first `vector.vscale` operation is found within
99   // the `ValueBoundsOpInterface` implementation then reused from there on.
100   Value vscale = nullptr;
101 };
102 
/// Namespace-level alias so that clients can write
/// `vector::ConstantOrScalableBound` instead of spelling out the nested
/// `ScalableValueBoundsConstraintSet::ConstantOrScalableBound`.
using ConstantOrScalableBound =
    ScalableValueBoundsConstraintSet::ConstantOrScalableBound;
105 
106 } // namespace mlir::vector
107 
108 #endif // MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H
109