1 //===- AffineExpr.h - MLIR Affine Expr Class --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // An affine expression is an affine combination of dimension identifiers and
10 // symbols, including ceildiv/floordiv/mod by a constant integer.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #ifndef MLIR_IR_AFFINEEXPR_H
15 #define MLIR_IR_AFFINEEXPR_H
16
17 #include "mlir/IR/Visitors.h"
18 #include "mlir/Support/LLVM.h"
19 #include "llvm/ADT/DenseMapInfo.h"
20 #include "llvm/ADT/Hashing.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Casting.h"
23 #include <type_traits>
24
25 namespace mlir {
26
// Forward declarations: this header only needs pointers/references to these.
class AffineMap;
class IntegerSet;
class MLIRContext;

namespace detail {

// Storage (implementation) types behind the AffineExpr value handles.
struct AffineBinaryOpExprStorage;
struct AffineConstantExprStorage;
struct AffineDimExprStorage;
struct AffineExprStorage;

} // namespace detail
39
/// Discriminator for the concrete kind of an AffineExpr.
/// NOTE: declaration order is significant. All binary-operation kinds come
/// first, so "is a binary op" is tested with a single `<=` comparison against
/// LAST_AFFINE_BINARY_OP (see AffineExpr::isa and the CastInfo specialization).
enum class AffineExprKind {
  /// Addition. Subtraction has no kind of its own; it is an add of a term
  /// multiplied by -1.
  Add,
  /// RHS of mul is always a constant or a symbolic expression.
  Mul,
  /// RHS of mod is always a constant or a symbolic expression with a positive
  /// value.
  Mod,
  /// RHS of floordiv is always a constant or a symbolic expression.
  FloorDiv,
  /// RHS of ceildiv is always a constant or a symbolic expression.
  CeilDiv,

  /// This is a marker for the last affine binary op. The range of binary
  /// ops is expected to be this element and earlier.
  LAST_AFFINE_BINARY_OP = CeilDiv,

  /// Constant integer.
  Constant,
  /// Dimensional identifier.
  DimId,
  /// Symbolic identifier.
  SymbolId,
};
63
64 /// Base type for affine expression.
65 /// AffineExpr's are immutable value types with intuitive operators to
66 /// operate on chainable, lightweight compositions.
67 /// An AffineExpr is an interface to the underlying storage type pointer.
68 class AffineExpr {
69 public:
70 using ImplType = detail::AffineExprStorage;
71
AffineExpr()72 constexpr AffineExpr() {}
AffineExpr(const ImplType * expr)73 /* implicit */ AffineExpr(const ImplType *expr)
74 : expr(const_cast<ImplType *>(expr)) {}
75
76 bool operator==(AffineExpr other) const { return expr == other.expr; }
77 bool operator!=(AffineExpr other) const { return !(*this == other); }
78 bool operator==(int64_t v) const;
79 bool operator!=(int64_t v) const { return !(*this == v); }
80 explicit operator bool() const { return expr; }
81
82 bool operator!() const { return expr == nullptr; }
83
84 template <typename U>
85 [[deprecated("Use llvm::isa<U>() instead")]] constexpr bool isa() const;
86
87 template <typename U>
88 [[deprecated("Use llvm::dyn_cast<U>() instead")]] U dyn_cast() const;
89
90 template <typename U>
91 [[deprecated("Use llvm::dyn_cast_or_null<U>() instead")]] U
92 dyn_cast_or_null() const;
93
94 template <typename U>
95 [[deprecated("Use llvm::cast<U>() instead")]] U cast() const;
96
97 MLIRContext *getContext() const;
98
99 /// Return the classification for this type.
100 AffineExprKind getKind() const;
101
102 void print(raw_ostream &os) const;
103 void dump() const;
104
105 /// Returns true if this expression is made out of only symbols and
106 /// constants, i.e., it does not involve dimensional identifiers.
107 bool isSymbolicOrConstant() const;
108
109 /// Returns true if this is a pure affine expression, i.e., multiplication,
110 /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
111 bool isPureAffine() const;
112
113 /// Returns the greatest known integral divisor of this affine expression. The
114 /// result is always positive.
115 int64_t getLargestKnownDivisor() const;
116
117 /// Return true if the affine expression is a multiple of 'factor'.
118 bool isMultipleOf(int64_t factor) const;
119
120 /// Return true if the affine expression involves AffineDimExpr `position`.
121 bool isFunctionOfDim(unsigned position) const;
122
123 /// Return true if the affine expression involves AffineSymbolExpr `position`.
124 bool isFunctionOfSymbol(unsigned position) const;
125
126 /// Walk all of the AffineExpr's in this expression in postorder. This allows
127 /// a lambda walk function that can either return `void` or a WalkResult. With
128 /// a WalkResult, interrupting is supported.
129 template <typename FnT, typename RetT = detail::walkResultType<FnT>>
walk(FnT && callback)130 RetT walk(FnT &&callback) const {
131 return walk<RetT>(*this, callback);
132 }
133
134 /// This method substitutes any uses of dimensions and symbols (e.g.
135 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
136 /// This is a dense replacement method: a replacement must be specified for
137 /// every single dim and symbol.
138 AffineExpr replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
139 ArrayRef<AffineExpr> symReplacements) const;
140
141 /// Dim-only version of replaceDimsAndSymbols.
142 AffineExpr replaceDims(ArrayRef<AffineExpr> dimReplacements) const;
143
144 /// Symbol-only version of replaceDimsAndSymbols.
145 AffineExpr replaceSymbols(ArrayRef<AffineExpr> symReplacements) const;
146
147 /// Sparse replace method. Replace `expr` by `replacement` and return the
148 /// modified expression tree.
149 AffineExpr replace(AffineExpr expr, AffineExpr replacement) const;
150
151 /// Sparse replace method. If `*this` appears in `map` replaces it by
152 /// `map[*this]` and return the modified expression tree. Otherwise traverse
153 /// `*this` and apply replace with `map` on its subexpressions.
154 AffineExpr replace(const DenseMap<AffineExpr, AffineExpr> &map) const;
155
156 /// Replace dims[offset ... numDims)
157 /// by dims[offset + shift ... shift + numDims).
158 AffineExpr shiftDims(unsigned numDims, unsigned shift,
159 unsigned offset = 0) const;
160
161 /// Replace symbols[offset ... numSymbols)
162 /// by symbols[offset + shift ... shift + numSymbols).
163 AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift,
164 unsigned offset = 0) const;
165
166 AffineExpr operator+(int64_t v) const;
167 AffineExpr operator+(AffineExpr other) const;
168 AffineExpr operator-() const;
169 AffineExpr operator-(int64_t v) const;
170 AffineExpr operator-(AffineExpr other) const;
171 AffineExpr operator*(int64_t v) const;
172 AffineExpr operator*(AffineExpr other) const;
173 AffineExpr floorDiv(uint64_t v) const;
174 AffineExpr floorDiv(AffineExpr other) const;
175 AffineExpr ceilDiv(uint64_t v) const;
176 AffineExpr ceilDiv(AffineExpr other) const;
177 AffineExpr operator%(uint64_t v) const;
178 AffineExpr operator%(AffineExpr other) const;
179
180 /// Compose with an AffineMap.
181 /// Returns the composition of this AffineExpr with `map`.
182 ///
183 /// Prerequisites:
184 /// `this` and `map` are composable, i.e. that the number of AffineDimExpr of
185 /// `this` is smaller than the number of results of `map`. If a result of a
186 /// map does not have a corresponding AffineDimExpr, that result simply does
187 /// not appear in the produced AffineExpr.
188 ///
189 /// Example:
190 /// expr: `d0 + d2`
191 /// map: `(d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)`
192 /// returned expr: `d0 * 2 + d1 + d2 + s1`
193 AffineExpr compose(AffineMap map) const;
194
195 friend ::llvm::hash_code hash_value(AffineExpr arg);
196
197 /// Methods supporting C API.
getAsOpaquePointer()198 const void *getAsOpaquePointer() const {
199 return static_cast<const void *>(expr);
200 }
getFromOpaquePointer(const void * pointer)201 static AffineExpr getFromOpaquePointer(const void *pointer) {
202 return AffineExpr(
203 reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
204 }
205
getImpl()206 ImplType *getImpl() const { return expr; }
207
208 protected:
209 ImplType *expr{nullptr};
210
211 private:
212 /// A trampoline for the templated non-static AffineExpr::walk method to
213 /// dispatch lambda `callback`'s of either a void result type or a
214 /// WalkResult type. Walk all of the AffineExprs in `e` in postorder. Users
215 /// should use the regular (non-static) `walk` method.
216 template <typename WalkRetTy>
217 static WalkRetTy walk(AffineExpr e,
218 function_ref<WalkRetTy(AffineExpr)> callback);
219 };
220
221 /// Affine binary operation expression. An affine binary operation could be an
222 /// add, mul, floordiv, ceildiv, or a modulo operation. (Subtraction is
223 /// represented through a multiply by -1 and add.) These expressions are always
224 /// constructed in a simplified form. For eg., the LHS and RHS operands can't
225 /// both be constants. There are additional canonicalizing rules depending on
226 /// the op type: see checks in the constructor.
227 class AffineBinaryOpExpr : public AffineExpr {
228 public:
229 using ImplType = detail::AffineBinaryOpExprStorage;
230 /* implicit */ AffineBinaryOpExpr(AffineExpr::ImplType *ptr);
231 AffineExpr getLHS() const;
232 AffineExpr getRHS() const;
233 };
234
235 /// A dimensional identifier appearing in an affine expression.
236 class AffineDimExpr : public AffineExpr {
237 public:
238 using ImplType = detail::AffineDimExprStorage;
239 /* implicit */ AffineDimExpr(AffineExpr::ImplType *ptr);
240 unsigned getPosition() const;
241 };
242
243 /// A symbolic identifier appearing in an affine expression.
244 class AffineSymbolExpr : public AffineExpr {
245 public:
246 using ImplType = detail::AffineDimExprStorage;
247 /* implicit */ AffineSymbolExpr(AffineExpr::ImplType *ptr);
248 unsigned getPosition() const;
249 };
250
251 /// An integer constant appearing in affine expression.
252 class AffineConstantExpr : public AffineExpr {
253 public:
254 using ImplType = detail::AffineConstantExprStorage;
255 /* implicit */ AffineConstantExpr(AffineExpr::ImplType *ptr = nullptr);
256 int64_t getValue() const;
257 };
258
259 /// Make AffineExpr hashable.
hash_value(AffineExpr arg)260 inline ::llvm::hash_code hash_value(AffineExpr arg) {
261 return ::llvm::hash_value(arg.expr);
262 }
263
264 inline AffineExpr operator+(int64_t val, AffineExpr expr) { return expr + val; }
265 inline AffineExpr operator*(int64_t val, AffineExpr expr) { return expr * val; }
266 inline AffineExpr operator-(int64_t val, AffineExpr expr) {
267 return expr * (-1) + val;
268 }
269
270 /// These free functions allow clients of the API to not use classes in detail.
271 AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context);
272 AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context);
273 AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context);
274 SmallVector<AffineExpr> getAffineConstantExprs(ArrayRef<int64_t> constants,
275 MLIRContext *context);
276 AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
277 AffineExpr rhs);
278
279 /// Constructs an affine expression from a flat ArrayRef. If there are local
280 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
281 /// products expression, 'localExprs' is expected to have the AffineExpr
282 /// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
283 /// format [dims, symbols, locals, constant term].
284 AffineExpr getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
285 unsigned numDims, unsigned numSymbols,
286 ArrayRef<AffineExpr> localExprs,
287 MLIRContext *context);
288
289 raw_ostream &operator<<(raw_ostream &os, AffineExpr expr);
290
291 template <typename U>
isa()292 constexpr bool AffineExpr::isa() const {
293 if constexpr (std::is_same_v<U, AffineBinaryOpExpr>)
294 return getKind() <= AffineExprKind::LAST_AFFINE_BINARY_OP;
295 if constexpr (std::is_same_v<U, AffineDimExpr>)
296 return getKind() == AffineExprKind::DimId;
297 if constexpr (std::is_same_v<U, AffineSymbolExpr>)
298 return getKind() == AffineExprKind::SymbolId;
299 if constexpr (std::is_same_v<U, AffineConstantExpr>)
300 return getKind() == AffineExprKind::Constant;
301 }
302 template <typename U>
dyn_cast()303 U AffineExpr::dyn_cast() const {
304 return llvm::dyn_cast<U>(*this);
305 }
306 template <typename U>
dyn_cast_or_null()307 U AffineExpr::dyn_cast_or_null() const {
308 return llvm::dyn_cast_or_null<U>(*this);
309 }
310 template <typename U>
cast()311 U AffineExpr::cast() const {
312 return llvm::cast<U>(*this);
313 }
314
315 /// Simplify an affine expression by flattening and some amount of simple
316 /// analysis. This has complexity linear in the number of nodes in 'expr'.
317 /// Returns the simplified expression, which is the same as the input expression
318 /// if it can't be simplified. When `expr` is semi-affine, a simplified
319 /// semi-affine expression is constructed in the sorted order of dimension and
320 /// symbol positions.
321 AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims,
322 unsigned numSymbols);
323
324 namespace detail {
325 template <int N>
bindDims(MLIRContext * ctx)326 void bindDims(MLIRContext *ctx) {}
327
328 template <int N, typename AffineExprTy, typename... AffineExprTy2>
bindDims(MLIRContext * ctx,AffineExprTy & e,AffineExprTy2 &...exprs)329 void bindDims(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) {
330 e = getAffineDimExpr(N, ctx);
331 bindDims<N + 1, AffineExprTy2 &...>(ctx, exprs...);
332 }
333
334 template <int N>
bindSymbols(MLIRContext * ctx)335 void bindSymbols(MLIRContext *ctx) {}
336
337 template <int N, typename AffineExprTy, typename... AffineExprTy2>
bindSymbols(MLIRContext * ctx,AffineExprTy & e,AffineExprTy2 &...exprs)338 void bindSymbols(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) {
339 e = getAffineSymbolExpr(N, ctx);
340 bindSymbols<N + 1, AffineExprTy2 &...>(ctx, exprs...);
341 }
342
343 } // namespace detail
344
345 /// Bind a list of AffineExpr references to DimExpr at positions:
346 /// [0 .. sizeof...(exprs)]
347 template <typename... AffineExprTy>
bindDims(MLIRContext * ctx,AffineExprTy &...exprs)348 void bindDims(MLIRContext *ctx, AffineExprTy &...exprs) {
349 detail::bindDims<0>(ctx, exprs...);
350 }
351
352 template <typename AffineExprTy>
bindDimsList(MLIRContext * ctx,MutableArrayRef<AffineExprTy> exprs)353 void bindDimsList(MLIRContext *ctx, MutableArrayRef<AffineExprTy> exprs) {
354 int idx = 0;
355 for (AffineExprTy &e : exprs)
356 e = getAffineDimExpr(idx++, ctx);
357 }
358
359 /// Bind a list of AffineExpr references to SymbolExpr at positions:
360 /// [0 .. sizeof...(exprs)]
361 template <typename... AffineExprTy>
bindSymbols(MLIRContext * ctx,AffineExprTy &...exprs)362 void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs) {
363 detail::bindSymbols<0>(ctx, exprs...);
364 }
365
366 template <typename AffineExprTy>
bindSymbolsList(MLIRContext * ctx,MutableArrayRef<AffineExprTy> exprs)367 void bindSymbolsList(MLIRContext *ctx, MutableArrayRef<AffineExprTy> exprs) {
368 int idx = 0;
369 for (AffineExprTy &e : exprs)
370 e = getAffineSymbolExpr(idx++, ctx);
371 }
372
373 /// Get a lower or upper (depending on `isUpper`) bound for `expr` while using
374 /// the constant lower and upper bounds for its inputs provided in
375 /// `constLowerBounds` and `constUpperBounds`. Return std::nullopt if such a
376 /// bound can't be computed. This method only handles simple sum of product
377 /// expressions (w.r.t constant coefficients) so as to not depend on anything
378 /// heavyweight in `Analysis`. Expressions of the form: c0*d0 + c1*d1 + c2*s0 +
379 /// ... + c_n are handled. Expressions involving floordiv, ceildiv, mod or
380 /// semi-affine ones will lead a none being returned.
381 std::optional<int64_t>
382 getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
383 ArrayRef<std::optional<int64_t>> constLowerBounds,
384 ArrayRef<std::optional<int64_t>> constUpperBounds,
385 bool isUpper);
386
387 } // namespace mlir
388
389 namespace llvm {
390
391 // AffineExpr hash just like pointers
392 template <>
393 struct DenseMapInfo<mlir::AffineExpr> {
394 static mlir::AffineExpr getEmptyKey() {
395 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
396 return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));
397 }
398 static mlir::AffineExpr getTombstoneKey() {
399 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
400 return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));
401 }
402 static unsigned getHashValue(mlir::AffineExpr val) {
403 return mlir::hash_value(val);
404 }
405 static bool isEqual(mlir::AffineExpr LHS, mlir::AffineExpr RHS) {
406 return LHS == RHS;
407 }
408 };
409
410 /// Add support for llvm style casts. We provide a cast between To and From if
411 /// From is mlir::AffineExpr or derives from it.
412 template <typename To, typename From>
413 struct CastInfo<To, From,
414 std::enable_if_t<std::is_same_v<mlir::AffineExpr,
415 std::remove_const_t<From>> ||
416 std::is_base_of_v<mlir::AffineExpr, From>>>
417 : NullableValueCastFailed<To>,
418 DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
419
420 static inline bool isPossible(mlir::AffineExpr expr) {
421 /// Return a constant true instead of a dynamic true when casting to self or
422 /// up the hierarchy.
423 if constexpr (std::is_base_of_v<To, From>) {
424 return true;
425 } else {
426 if constexpr (std::is_same_v<To, ::mlir::AffineBinaryOpExpr>)
427 return expr.getKind() <= ::mlir::AffineExprKind::LAST_AFFINE_BINARY_OP;
428 if constexpr (std::is_same_v<To, ::mlir::AffineDimExpr>)
429 return expr.getKind() == ::mlir::AffineExprKind::DimId;
430 if constexpr (std::is_same_v<To, ::mlir::AffineSymbolExpr>)
431 return expr.getKind() == ::mlir::AffineExprKind::SymbolId;
432 if constexpr (std::is_same_v<To, ::mlir::AffineConstantExpr>)
433 return expr.getKind() == ::mlir::AffineExprKind::Constant;
434 }
435 }
436 static inline To doCast(mlir::AffineExpr expr) { return To(expr.getImpl()); }
437 };
438
439 } // namespace llvm
440
441 #endif // MLIR_IR_AFFINEEXPR_H
442