xref: /llvm-project/mlir/include/mlir/IR/AffineExpr.h (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- AffineExpr.h - MLIR Affine Expr Class --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // An affine expression is an affine combination of dimension identifiers and
10 // symbols, including ceildiv/floordiv/mod by a constant integer.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_IR_AFFINEEXPR_H
15 #define MLIR_IR_AFFINEEXPR_H
16 
17 #include "mlir/IR/Visitors.h"
18 #include "mlir/Support/LLVM.h"
19 #include "llvm/ADT/DenseMapInfo.h"
20 #include "llvm/ADT/Hashing.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Casting.h"
23 #include <type_traits>
24 
25 namespace mlir {
26 
27 class MLIRContext;
28 class AffineMap;
29 class IntegerSet;
30 
31 namespace detail {
32 
33 struct AffineExprStorage;
34 struct AffineBinaryOpExprStorage;
35 struct AffineDimExprStorage;
36 struct AffineConstantExprStorage;
37 
38 } // namespace detail
39 
/// Discriminator identifying which kind of node an affine expression is.
///
/// Enumerator order is significant: all binary operators are listed first so
/// that "is a binary op" can be tested as `kind <= LAST_AFFINE_BINARY_OP`,
/// which is how the cast helpers further down in this header test it.
enum class AffineExprKind {
  /// Binary addition.
  Add,
  /// RHS of mul is always a constant or a symbolic expression.
  Mul,
  /// RHS of mod is always a constant or a symbolic expression with a positive
  /// value.
  Mod,
  /// RHS of floordiv is always a constant or a symbolic expression.
  FloorDiv,
  /// RHS of ceildiv is always a constant or a symbolic expression.
  CeilDiv,

  /// This is a marker for the last affine binary op. The range of binary
  /// op's is expected to be this element and earlier. It aliases CeilDiv, so
  /// new binary kinds must be inserted before this marker and the marker
  /// updated accordingly.
  LAST_AFFINE_BINARY_OP = CeilDiv,

  /// Constant integer.
  Constant,
  /// Dimensional identifier.
  DimId,
  /// Symbolic identifier.
  SymbolId,
};
63 
64 /// Base type for affine expression.
65 /// AffineExpr's are immutable value types with intuitive operators to
66 /// operate on chainable, lightweight compositions.
67 /// An AffineExpr is an interface to the underlying storage type pointer.
68 class AffineExpr {
69 public:
70   using ImplType = detail::AffineExprStorage;
71 
AffineExpr()72   constexpr AffineExpr() {}
AffineExpr(const ImplType * expr)73   /* implicit */ AffineExpr(const ImplType *expr)
74       : expr(const_cast<ImplType *>(expr)) {}
75 
76   bool operator==(AffineExpr other) const { return expr == other.expr; }
77   bool operator!=(AffineExpr other) const { return !(*this == other); }
78   bool operator==(int64_t v) const;
79   bool operator!=(int64_t v) const { return !(*this == v); }
80   explicit operator bool() const { return expr; }
81 
82   bool operator!() const { return expr == nullptr; }
83 
84   template <typename U>
85   [[deprecated("Use llvm::isa<U>() instead")]] constexpr bool isa() const;
86 
87   template <typename U>
88   [[deprecated("Use llvm::dyn_cast<U>() instead")]] U dyn_cast() const;
89 
90   template <typename U>
91   [[deprecated("Use llvm::dyn_cast_or_null<U>() instead")]] U
92   dyn_cast_or_null() const;
93 
94   template <typename U>
95   [[deprecated("Use llvm::cast<U>() instead")]] U cast() const;
96 
97   MLIRContext *getContext() const;
98 
99   /// Return the classification for this type.
100   AffineExprKind getKind() const;
101 
102   void print(raw_ostream &os) const;
103   void dump() const;
104 
105   /// Returns true if this expression is made out of only symbols and
106   /// constants, i.e., it does not involve dimensional identifiers.
107   bool isSymbolicOrConstant() const;
108 
109   /// Returns true if this is a pure affine expression, i.e., multiplication,
110   /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
111   bool isPureAffine() const;
112 
113   /// Returns the greatest known integral divisor of this affine expression. The
114   /// result is always positive.
115   int64_t getLargestKnownDivisor() const;
116 
117   /// Return true if the affine expression is a multiple of 'factor'.
118   bool isMultipleOf(int64_t factor) const;
119 
120   /// Return true if the affine expression involves AffineDimExpr `position`.
121   bool isFunctionOfDim(unsigned position) const;
122 
123   /// Return true if the affine expression involves AffineSymbolExpr `position`.
124   bool isFunctionOfSymbol(unsigned position) const;
125 
126   /// Walk all of the AffineExpr's in this expression in postorder. This allows
127   /// a lambda walk function that can either return `void` or a WalkResult. With
128   /// a WalkResult, interrupting is supported.
129   template <typename FnT, typename RetT = detail::walkResultType<FnT>>
walk(FnT && callback)130   RetT walk(FnT &&callback) const {
131     return walk<RetT>(*this, callback);
132   }
133 
134   /// This method substitutes any uses of dimensions and symbols (e.g.
135   /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
136   /// This is a dense replacement method: a replacement must be specified for
137   /// every single dim and symbol.
138   AffineExpr replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
139                                    ArrayRef<AffineExpr> symReplacements) const;
140 
141   /// Dim-only version of replaceDimsAndSymbols.
142   AffineExpr replaceDims(ArrayRef<AffineExpr> dimReplacements) const;
143 
144   /// Symbol-only version of replaceDimsAndSymbols.
145   AffineExpr replaceSymbols(ArrayRef<AffineExpr> symReplacements) const;
146 
147   /// Sparse replace method. Replace `expr` by `replacement` and return the
148   /// modified expression tree.
149   AffineExpr replace(AffineExpr expr, AffineExpr replacement) const;
150 
151   /// Sparse replace method. If `*this` appears in `map` replaces it by
152   /// `map[*this]` and return the modified expression tree. Otherwise traverse
153   /// `*this` and apply replace with `map` on its subexpressions.
154   AffineExpr replace(const DenseMap<AffineExpr, AffineExpr> &map) const;
155 
156   /// Replace dims[offset ... numDims)
157   /// by dims[offset + shift ... shift + numDims).
158   AffineExpr shiftDims(unsigned numDims, unsigned shift,
159                        unsigned offset = 0) const;
160 
161   /// Replace symbols[offset ... numSymbols)
162   /// by symbols[offset + shift ... shift + numSymbols).
163   AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift,
164                           unsigned offset = 0) const;
165 
166   AffineExpr operator+(int64_t v) const;
167   AffineExpr operator+(AffineExpr other) const;
168   AffineExpr operator-() const;
169   AffineExpr operator-(int64_t v) const;
170   AffineExpr operator-(AffineExpr other) const;
171   AffineExpr operator*(int64_t v) const;
172   AffineExpr operator*(AffineExpr other) const;
173   AffineExpr floorDiv(uint64_t v) const;
174   AffineExpr floorDiv(AffineExpr other) const;
175   AffineExpr ceilDiv(uint64_t v) const;
176   AffineExpr ceilDiv(AffineExpr other) const;
177   AffineExpr operator%(uint64_t v) const;
178   AffineExpr operator%(AffineExpr other) const;
179 
180   /// Compose with an AffineMap.
181   /// Returns the composition of this AffineExpr with `map`.
182   ///
183   /// Prerequisites:
184   /// `this` and `map` are composable, i.e. that the number of AffineDimExpr of
185   /// `this` is smaller than the number of results of `map`. If a result of a
186   /// map does not have a corresponding AffineDimExpr, that result simply does
187   /// not appear in the produced AffineExpr.
188   ///
189   /// Example:
190   ///   expr: `d0 + d2`
191   ///   map:  `(d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)`
192   ///   returned expr: `d0 * 2 + d1 + d2 + s1`
193   AffineExpr compose(AffineMap map) const;
194 
195   friend ::llvm::hash_code hash_value(AffineExpr arg);
196 
197   /// Methods supporting C API.
getAsOpaquePointer()198   const void *getAsOpaquePointer() const {
199     return static_cast<const void *>(expr);
200   }
getFromOpaquePointer(const void * pointer)201   static AffineExpr getFromOpaquePointer(const void *pointer) {
202     return AffineExpr(
203         reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
204   }
205 
getImpl()206   ImplType *getImpl() const { return expr; }
207 
208 protected:
209   ImplType *expr{nullptr};
210 
211 private:
212   /// A trampoline for the templated non-static AffineExpr::walk method to
213   /// dispatch lambda `callback`'s of either a void result type or a
214   /// WalkResult type. Walk all of the AffineExprs in `e` in postorder. Users
215   /// should use the regular (non-static) `walk` method.
216   template <typename WalkRetTy>
217   static WalkRetTy walk(AffineExpr e,
218                         function_ref<WalkRetTy(AffineExpr)> callback);
219 };
220 
/// Affine binary operation expression. An affine binary operation could be an
/// add, mul, floordiv, ceildiv, or a modulo operation. (Subtraction is
/// represented through a multiply by -1 and add.) These expressions are always
/// constructed in a simplified form. For example, the LHS and RHS operands
/// can't both be constants. There are additional canonicalizing rules depending
/// on the op type: see checks in the constructor.
class AffineBinaryOpExpr : public AffineExpr {
public:
  using ImplType = detail::AffineBinaryOpExprStorage;
  /// Wraps a generic storage pointer as a binary-op expression (defined out of
  /// line).
  /* implicit */ AffineBinaryOpExpr(AffineExpr::ImplType *ptr);
  /// Left-hand operand.
  AffineExpr getLHS() const;
  /// Right-hand operand.
  AffineExpr getRHS() const;
};
234 
/// A dimensional identifier appearing in an affine expression.
class AffineDimExpr : public AffineExpr {
public:
  using ImplType = detail::AffineDimExprStorage;
  /// Wraps a generic storage pointer as a dim expression (defined out of
  /// line).
  /* implicit */ AffineDimExpr(AffineExpr::ImplType *ptr);
  /// Position of this dimensional identifier.
  unsigned getPosition() const;
};
242 
/// A symbolic identifier appearing in an affine expression.
class AffineSymbolExpr : public AffineExpr {
public:
  // NOTE(review): ImplType is AffineDimExprStorage, the same storage type that
  // AffineDimExpr uses; no separate symbol storage type is forward-declared in
  // this header. Presumably dims and symbols intentionally share one storage
  // class -- confirm against the storage definitions.
  using ImplType = detail::AffineDimExprStorage;
  /// Wraps a generic storage pointer as a symbol expression (defined out of
  /// line).
  /* implicit */ AffineSymbolExpr(AffineExpr::ImplType *ptr);
  /// Position of this symbolic identifier.
  unsigned getPosition() const;
};
250 
/// An integer constant appearing in affine expression.
class AffineConstantExpr : public AffineExpr {
public:
  using ImplType = detail::AffineConstantExprStorage;
  /// Wraps a generic storage pointer as a constant expression (defined out of
  /// line). Unlike the sibling expression classes, `ptr` defaults to nullptr,
  /// which makes this class default-constructible.
  /* implicit */ AffineConstantExpr(AffineExpr::ImplType *ptr = nullptr);
  /// The integer value held by this constant.
  int64_t getValue() const;
};
258 
259 /// Make AffineExpr hashable.
hash_value(AffineExpr arg)260 inline ::llvm::hash_code hash_value(AffineExpr arg) {
261   return ::llvm::hash_value(arg.expr);
262 }
263 
264 inline AffineExpr operator+(int64_t val, AffineExpr expr) { return expr + val; }
265 inline AffineExpr operator*(int64_t val, AffineExpr expr) { return expr * val; }
266 inline AffineExpr operator-(int64_t val, AffineExpr expr) {
267   return expr * (-1) + val;
268 }
269 
/// These free functions allow clients of the API to not use classes in detail.
/// All of them are defined out of line.

/// Returns the dimensional identifier expression for `position` in `context`.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context);
/// Returns the symbolic identifier expression for `position` in `context`.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context);
/// Returns the constant expression for `constant` in `context`.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context);
/// Vector form of getAffineConstantExpr.
/// NOTE(review): presumably yields one expression per entry of `constants`,
/// in order -- confirm against the definition.
SmallVector<AffineExpr> getAffineConstantExprs(ArrayRef<int64_t> constants,
                                               MLIRContext *context);
/// Returns the binary expression of the given `kind` over `lhs` and `rhs`.
/// No context parameter is taken.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
                                 AffineExpr rhs);
278 
/// Constructs an affine expression from a flat ArrayRef. If there are local
/// identifiers (neither dimensional nor symbolic) that appear in the sum of
/// products expression, 'localExprs' is expected to have the AffineExpr
/// for it, and is substituted into. The ArrayRef 'flatExprs' is expected to be
/// in the format [dims, symbols, locals, constant term].
AffineExpr getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
                                     unsigned numDims, unsigned numSymbols,
                                     ArrayRef<AffineExpr> localExprs,
                                     MLIRContext *context);
288 
/// Stream insertion for AffineExpr; defined out of line.
/// NOTE(review): presumably equivalent to `expr.print(os)` -- confirm against
/// the definition.
raw_ostream &operator<<(raw_ostream &os, AffineExpr expr);
290 
291 template <typename U>
isa()292 constexpr bool AffineExpr::isa() const {
293   if constexpr (std::is_same_v<U, AffineBinaryOpExpr>)
294     return getKind() <= AffineExprKind::LAST_AFFINE_BINARY_OP;
295   if constexpr (std::is_same_v<U, AffineDimExpr>)
296     return getKind() == AffineExprKind::DimId;
297   if constexpr (std::is_same_v<U, AffineSymbolExpr>)
298     return getKind() == AffineExprKind::SymbolId;
299   if constexpr (std::is_same_v<U, AffineConstantExpr>)
300     return getKind() == AffineExprKind::Constant;
301 }
302 template <typename U>
dyn_cast()303 U AffineExpr::dyn_cast() const {
304   return llvm::dyn_cast<U>(*this);
305 }
306 template <typename U>
dyn_cast_or_null()307 U AffineExpr::dyn_cast_or_null() const {
308   return llvm::dyn_cast_or_null<U>(*this);
309 }
310 template <typename U>
cast()311 U AffineExpr::cast() const {
312   return llvm::cast<U>(*this);
313 }
314 
/// Simplify an affine expression by flattening and some amount of simple
/// analysis. This has complexity linear in the number of nodes in 'expr'.
/// Returns the simplified expression, which is the same as the input expression
/// if it can't be simplified. When `expr` is semi-affine, a simplified
/// semi-affine expression is constructed in the sorted order of dimension and
/// symbol positions.
///
/// NOTE(review): `numDims` / `numSymbols` are presumably the number of
/// dimensions and symbols of the space `expr` lives in -- confirm against the
/// definition.
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims,
                              unsigned numSymbols);
323 
324 namespace detail {
325 template <int N>
bindDims(MLIRContext * ctx)326 void bindDims(MLIRContext *ctx) {}
327 
328 template <int N, typename AffineExprTy, typename... AffineExprTy2>
bindDims(MLIRContext * ctx,AffineExprTy & e,AffineExprTy2 &...exprs)329 void bindDims(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) {
330   e = getAffineDimExpr(N, ctx);
331   bindDims<N + 1, AffineExprTy2 &...>(ctx, exprs...);
332 }
333 
334 template <int N>
bindSymbols(MLIRContext * ctx)335 void bindSymbols(MLIRContext *ctx) {}
336 
337 template <int N, typename AffineExprTy, typename... AffineExprTy2>
bindSymbols(MLIRContext * ctx,AffineExprTy & e,AffineExprTy2 &...exprs)338 void bindSymbols(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) {
339   e = getAffineSymbolExpr(N, ctx);
340   bindSymbols<N + 1, AffineExprTy2 &...>(ctx, exprs...);
341 }
342 
343 } // namespace detail
344 
/// Bind a list of AffineExpr references to DimExpr at positions:
///   [0 .. sizeof...(exprs))
/// i.e. the first reference gets position 0 and the last gets position
/// `sizeof...(exprs) - 1`. Each reference is overwritten by assignment.
template <typename... AffineExprTy>
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs) {
  // Start the position counter at 0; the detail overloads increment it once
  // per expression.
  detail::bindDims<0>(ctx, exprs...);
}
351 
352 template <typename AffineExprTy>
bindDimsList(MLIRContext * ctx,MutableArrayRef<AffineExprTy> exprs)353 void bindDimsList(MLIRContext *ctx, MutableArrayRef<AffineExprTy> exprs) {
354   int idx = 0;
355   for (AffineExprTy &e : exprs)
356     e = getAffineDimExpr(idx++, ctx);
357 }
358 
/// Bind a list of AffineExpr references to SymbolExpr at positions:
///   [0 .. sizeof...(exprs))
/// i.e. the first reference gets position 0 and the last gets position
/// `sizeof...(exprs) - 1`. Each reference is overwritten by assignment.
template <typename... AffineExprTy>
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs) {
  // Start the position counter at 0; the detail overloads increment it once
  // per expression.
  detail::bindSymbols<0>(ctx, exprs...);
}
365 
366 template <typename AffineExprTy>
bindSymbolsList(MLIRContext * ctx,MutableArrayRef<AffineExprTy> exprs)367 void bindSymbolsList(MLIRContext *ctx, MutableArrayRef<AffineExprTy> exprs) {
368   int idx = 0;
369   for (AffineExprTy &e : exprs)
370     e = getAffineSymbolExpr(idx++, ctx);
371 }
372 
/// Get a lower or upper (depending on `isUpper`) bound for `expr` while using
/// the constant lower and upper bounds for its inputs provided in
/// `constLowerBounds` and `constUpperBounds`. Return std::nullopt if such a
/// bound can't be computed. This method only handles simple sum of product
/// expressions (w.r.t constant coefficients) so as to not depend on anything
/// heavyweight in `Analysis`. Expressions of the form: c0*d0 + c1*d1 + c2*s0 +
/// ... + c_n are handled. Expressions involving floordiv, ceildiv, or mod, as
/// well as semi-affine expressions, cause std::nullopt to be returned.
std::optional<int64_t>
getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
                      ArrayRef<std::optional<int64_t>> constLowerBounds,
                      ArrayRef<std::optional<int64_t>> constUpperBounds,
                      bool isUpper);
386 
387 } // namespace mlir
388 
389 namespace llvm {
390 
391 // AffineExpr hash just like pointers
392 template <>
393 struct DenseMapInfo<mlir::AffineExpr> {
394   static mlir::AffineExpr getEmptyKey() {
395     auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
396     return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));
397   }
398   static mlir::AffineExpr getTombstoneKey() {
399     auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
400     return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));
401   }
402   static unsigned getHashValue(mlir::AffineExpr val) {
403     return mlir::hash_value(val);
404   }
405   static bool isEqual(mlir::AffineExpr LHS, mlir::AffineExpr RHS) {
406     return LHS == RHS;
407   }
408 };
409 
410 /// Add support for llvm style casts. We provide a cast between To and From if
411 /// From is mlir::AffineExpr or derives from it.
412 template <typename To, typename From>
413 struct CastInfo<To, From,
414                 std::enable_if_t<std::is_same_v<mlir::AffineExpr,
415                                                 std::remove_const_t<From>> ||
416                                  std::is_base_of_v<mlir::AffineExpr, From>>>
417     : NullableValueCastFailed<To>,
418       DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
419 
420   static inline bool isPossible(mlir::AffineExpr expr) {
421     /// Return a constant true instead of a dynamic true when casting to self or
422     /// up the hierarchy.
423     if constexpr (std::is_base_of_v<To, From>) {
424       return true;
425     } else {
426       if constexpr (std::is_same_v<To, ::mlir::AffineBinaryOpExpr>)
427         return expr.getKind() <= ::mlir::AffineExprKind::LAST_AFFINE_BINARY_OP;
428       if constexpr (std::is_same_v<To, ::mlir::AffineDimExpr>)
429         return expr.getKind() == ::mlir::AffineExprKind::DimId;
430       if constexpr (std::is_same_v<To, ::mlir::AffineSymbolExpr>)
431         return expr.getKind() == ::mlir::AffineExprKind::SymbolId;
432       if constexpr (std::is_same_v<To, ::mlir::AffineConstantExpr>)
433         return expr.getKind() == ::mlir::AffineExprKind::Constant;
434     }
435   }
436   static inline To doCast(mlir::AffineExpr expr) { return To(expr.getImpl()); }
437 };
438 
439 } // namespace llvm
440 
441 #endif // MLIR_IR_AFFINEEXPR_H
442