xref: /llvm-project/mlir/include/mlir/IR/AffineExprVisitor.h (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- AffineExprVisitor.h - MLIR AffineExpr Visitor Class ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the AffineExpr visitor class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_AFFINEEXPRVISITOR_H
14 #define MLIR_IR_AFFINEEXPRVISITOR_H
15 
16 #include "mlir/IR/AffineExpr.h"
17 #include "mlir/Support/LLVM.h"
18 #include "llvm/ADT/ArrayRef.h"
19 
20 namespace mlir {
21 
22 /// Base class for AffineExpr visitors/walkers.
23 ///
24 /// AffineExpr visitors are used when you want to perform different actions
/// for different kinds of AffineExprs without having to use lots of casts
/// and a big switch statement.
27 ///
28 /// To define your own visitor, inherit from this class, specifying your
29 /// new type for the 'SubClass' template parameter, and "override" visitXXX
30 /// functions in your class. This class is defined in terms of statically
31 /// resolved overloading, not virtual functions.
32 ///
33 /// The visitor is templated on its return type (`RetTy`). With a WalkResult
34 /// return type, the visitor supports interrupting walks.
35 ///
/// For example, here is a visitor that counts the number of AffineDimExprs
/// in an AffineExpr.
38 ///
///  /// Declare the class.  Note that we derive from AffineExprVisitor
///  /// instantiated with our new subclass type.
41 ///
42 ///  struct DimExprCounter : public AffineExprVisitor<DimExprCounter> {
43 ///    unsigned numDimExprs;
44 ///    DimExprCounter() : numDimExprs(0) {}
45 ///    void visitDimExpr(AffineDimExpr expr) { ++numDimExprs; }
46 ///  };
47 ///
///  And this class would be used like this:
///    DimExprCounter dec;
///    dec.walkPostOrder(affineExpr); // visit() alone would only look at the
///                                   // root node, not at its operands.
///    numDimExprs = dec.numDimExprs;
52 ///
53 /// AffineExprVisitor provides visit methods for the following binary affine
54 /// op expressions:
55 /// AffineBinaryAddOpExpr, AffineBinaryMulOpExpr,
56 /// AffineBinaryModOpExpr, AffineBinaryFloorDivOpExpr,
57 /// AffineBinaryCeilDivOpExpr. Note that default implementations of these
58 /// methods will call the general AffineBinaryOpExpr method.
59 ///
/// In addition, visit methods are provided for the following affine
/// expressions: AffineConstantExpr, AffineDimExpr, and
/// AffineSymbolExpr.
63 ///
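/// Note that if you don't implement visitXXX for some affine expression type,
/// the default is used: the per-operation binary methods forward to
/// visitAffineBinaryOpExpr, while visitAffineBinaryOpExpr and the leaf methods
/// (constant, dim, symbol) do nothing and return a default-constructed RetTy.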
66 ///
/// Note that this class is specifically designed as a template to avoid
/// virtual function call overhead. Defining and using an AffineExprVisitor is
/// just as efficient as having your own switch statement over the affine
/// expression kind.
71 template <typename SubClass, typename RetTy>
72 class AffineExprVisitorBase {
73 public:
74   // Function to visit an AffineExpr.
visit(AffineExpr expr)75   RetTy visit(AffineExpr expr) {
76     static_assert(std::is_base_of<AffineExprVisitorBase, SubClass>::value,
77                   "Must instantiate with a derived type of AffineExprVisitor");
78     auto self = static_cast<SubClass *>(this);
79     switch (expr.getKind()) {
80     case AffineExprKind::Add: {
81       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
82       return self->visitAddExpr(binOpExpr);
83     }
84     case AffineExprKind::Mul: {
85       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
86       return self->visitMulExpr(binOpExpr);
87     }
88     case AffineExprKind::Mod: {
89       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
90       return self->visitModExpr(binOpExpr);
91     }
92     case AffineExprKind::FloorDiv: {
93       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
94       return self->visitFloorDivExpr(binOpExpr);
95     }
96     case AffineExprKind::CeilDiv: {
97       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
98       return self->visitCeilDivExpr(binOpExpr);
99     }
100     case AffineExprKind::Constant:
101       return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
102     case AffineExprKind::DimId:
103       return self->visitDimExpr(cast<AffineDimExpr>(expr));
104     case AffineExprKind::SymbolId:
105       return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
106     }
107     llvm_unreachable("Unknown AffineExpr");
108   }
109 
110   //===--------------------------------------------------------------------===//
111   // Visitation functions... these functions provide default fallbacks in case
112   // the user does not specify what to do for a particular instruction type.
113   // The default behavior is to generalize the instruction type to its subtype
114   // and try visiting the subtype.  All of this should be inlined perfectly,
115   // because there are no virtual functions to get in the way.
116   //
117 
118   // Default visit methods. Note that the default op-specific binary op visit
119   // methods call the general visitAffineBinaryOpExpr visit method.
visitAffineBinaryOpExpr(AffineBinaryOpExpr expr)120   RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
visitAddExpr(AffineBinaryOpExpr expr)121   RetTy visitAddExpr(AffineBinaryOpExpr expr) {
122     return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
123   }
visitMulExpr(AffineBinaryOpExpr expr)124   RetTy visitMulExpr(AffineBinaryOpExpr expr) {
125     return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
126   }
visitModExpr(AffineBinaryOpExpr expr)127   RetTy visitModExpr(AffineBinaryOpExpr expr) {
128     return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
129   }
visitFloorDivExpr(AffineBinaryOpExpr expr)130   RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
131     return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
132   }
visitCeilDivExpr(AffineBinaryOpExpr expr)133   RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
134     return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
135   }
visitConstantExpr(AffineConstantExpr expr)136   RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
visitDimExpr(AffineDimExpr expr)137   RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
visitSymbolExpr(AffineSymbolExpr expr)138   RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
139 };
140 
141 /// See documentation for AffineExprVisitorBase. This visitor supports
142 /// interrupting walks when a `WalkResult` is used for `RetTy`.
143 template <typename SubClass, typename RetTy = void>
144 class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
145   //===--------------------------------------------------------------------===//
146   // Interface code - This is the public interface of the AffineExprVisitor
147   // that you use to visit affine expressions...
148 public:
149   // Function to walk an AffineExpr (in post order).
walkPostOrder(AffineExpr expr)150   RetTy walkPostOrder(AffineExpr expr) {
151     static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
152                   "Must instantiate with a derived type of AffineExprVisitor");
153     auto self = static_cast<SubClass *>(this);
154     switch (expr.getKind()) {
155     case AffineExprKind::Add: {
156       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
157       if constexpr (std::is_same<RetTy, WalkResult>::value) {
158         if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
159           return WalkResult::interrupt();
160       } else {
161         walkOperandsPostOrder(binOpExpr);
162       }
163       return self->visitAddExpr(binOpExpr);
164     }
165     case AffineExprKind::Mul: {
166       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
167       if constexpr (std::is_same<RetTy, WalkResult>::value) {
168         if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
169           return WalkResult::interrupt();
170       } else {
171         walkOperandsPostOrder(binOpExpr);
172       }
173       return self->visitMulExpr(binOpExpr);
174     }
175     case AffineExprKind::Mod: {
176       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
177       if constexpr (std::is_same<RetTy, WalkResult>::value) {
178         if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
179           return WalkResult::interrupt();
180       } else {
181         walkOperandsPostOrder(binOpExpr);
182       }
183       return self->visitModExpr(binOpExpr);
184     }
185     case AffineExprKind::FloorDiv: {
186       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
187       if constexpr (std::is_same<RetTy, WalkResult>::value) {
188         if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
189           return WalkResult::interrupt();
190       } else {
191         walkOperandsPostOrder(binOpExpr);
192       }
193       return self->visitFloorDivExpr(binOpExpr);
194     }
195     case AffineExprKind::CeilDiv: {
196       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
197       if constexpr (std::is_same<RetTy, WalkResult>::value) {
198         if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
199           return WalkResult::interrupt();
200       } else {
201         walkOperandsPostOrder(binOpExpr);
202       }
203       return self->visitCeilDivExpr(binOpExpr);
204     }
205     case AffineExprKind::Constant:
206       return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
207     case AffineExprKind::DimId:
208       return self->visitDimExpr(cast<AffineDimExpr>(expr));
209     case AffineExprKind::SymbolId:
210       return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
211     }
212     llvm_unreachable("Unknown AffineExpr");
213   }
214 
215 private:
216   // Walk the operands - each operand is itself walked in post order.
walkOperandsPostOrder(AffineBinaryOpExpr expr)217   RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
218     if constexpr (std::is_same<RetTy, WalkResult>::value) {
219       if (walkPostOrder(expr.getLHS()).wasInterrupted())
220         return WalkResult::interrupt();
221     } else {
222       walkPostOrder(expr.getLHS());
223     }
224     if constexpr (std::is_same<RetTy, WalkResult>::value) {
225       if (walkPostOrder(expr.getRHS()).wasInterrupted())
226         return WalkResult::interrupt();
227       return WalkResult::advance();
228     } else {
229       return walkPostOrder(expr.getRHS());
230     }
231   }
232 };
233 
234 template <typename SubClass>
235 class AffineExprVisitor<SubClass, LogicalResult>
236     : public AffineExprVisitorBase<SubClass, LogicalResult> {
237   //===--------------------------------------------------------------------===//
238   // Interface code - This is the public interface of the AffineExprVisitor
239   // that you use to visit affine expressions...
240 public:
241   // Function to walk an AffineExpr (in post order).
walkPostOrder(AffineExpr expr)242   LogicalResult walkPostOrder(AffineExpr expr) {
243     static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
244                   "Must instantiate with a derived type of AffineExprVisitor");
245     auto self = static_cast<SubClass *>(this);
246     switch (expr.getKind()) {
247     case AffineExprKind::Add: {
248       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
249       if (failed(walkOperandsPostOrder(binOpExpr)))
250         return failure();
251       return self->visitAddExpr(binOpExpr);
252     }
253     case AffineExprKind::Mul: {
254       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
255       if (failed(walkOperandsPostOrder(binOpExpr)))
256         return failure();
257       return self->visitMulExpr(binOpExpr);
258     }
259     case AffineExprKind::Mod: {
260       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
261       if (failed(walkOperandsPostOrder(binOpExpr)))
262         return failure();
263       return self->visitModExpr(binOpExpr);
264     }
265     case AffineExprKind::FloorDiv: {
266       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
267       if (failed(walkOperandsPostOrder(binOpExpr)))
268         return failure();
269       return self->visitFloorDivExpr(binOpExpr);
270     }
271     case AffineExprKind::CeilDiv: {
272       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
273       if (failed(walkOperandsPostOrder(binOpExpr)))
274         return failure();
275       return self->visitCeilDivExpr(binOpExpr);
276     }
277     case AffineExprKind::Constant:
278       return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
279     case AffineExprKind::DimId:
280       return self->visitDimExpr(cast<AffineDimExpr>(expr));
281     case AffineExprKind::SymbolId:
282       return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
283     }
284     llvm_unreachable("Unknown AffineExpr");
285   }
286 
287 private:
288   // Walk the operands - each operand is itself walked in post order.
walkOperandsPostOrder(AffineBinaryOpExpr expr)289   LogicalResult walkOperandsPostOrder(AffineBinaryOpExpr expr) {
290     if (failed(walkPostOrder(expr.getLHS())))
291       return failure();
292     if (failed(walkPostOrder(expr.getRHS())))
293       return failure();
294     return success();
295   }
296 };
297 
298 // This class is used to flatten a pure affine expression (AffineExpr,
299 // which is in a tree form) into a sum of products (w.r.t constants) when
300 // possible, and in that process simplifying the expression. For a modulo,
301 // floordiv, or a ceildiv expression, an additional identifier, called a local
302 // identifier, is introduced to rewrite the expression as a sum of product
303 // affine expression. Each local identifier is always and by construction a
304 // floordiv of a pure add/mul affine function of dimensional, symbolic, and
305 // other local identifiers, in a non-mutually recursive way. Hence, every local
306 // identifier can ultimately always be recovered as an affine function of
307 // dimensional and symbolic identifiers (involving floordiv's); note however
308 // that by AffineExpr construction, some floordiv combinations are converted to
309 // mod's. The result of the flattening is a flattened expression and a set of
310 // constraints involving just the local variables.
311 //
312 // d2 + (d0 + d1) floordiv 4  is flattened to d2 + q where 'q' is the local
313 // variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3.
314 //
315 // The simplification performed includes the accumulation of contributions for
316 // each dimensional and symbolic identifier together, the simplification of
317 // floordiv/ceildiv/mod expressions and other simplifications that in turn
318 // happen as a result. A simplification that this flattening naturally performs
319 // is of simplifying the numerator and denominator of floordiv/ceildiv, and
320 // folding a modulo expression to a zero, if possible. Three examples are below:
321 //
// (((d0 + 3 * d1) + d0) - 2 * d1) - d0  simplified to     d0 + d1
323 // (d0 - d0 mod 4 + 4) mod 4             simplified to     0
324 // (3*d0 + 2*d1 + d0) floordiv 2 + d1    simplified to     2*d0 + 2*d1
325 //
326 // The way the flattening works for the second example is as follows: d0 % 4 is
327 // replaced by d0 - 4*q with q being introduced: the expression then simplifies
328 // to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to
329 // zero. Note that an affine expression may not always be expressible purely as
330 // a sum of products involving just the original dimensional and symbolic
331 // identifiers due to the presence of modulo/floordiv/ceildiv expressions that
332 // may not be eliminated after simplification; in such cases, the final
333 // expression can be reconstructed by replacing the local identifiers with their
334 // corresponding explicit form stored in 'localExprs' (note that each of the
335 // explicit forms itself would have been simplified).
336 //
337 // The expression walk method here performs a linear time post order walk that
338 // performs the above simplifications through visit methods, with partial
339 // results being stored in 'operandExprStack'. When a parent expr is visited,
340 // the flattened expressions corresponding to its two operands would already be
341 // on the stack - the parent expression looks at the two flattened expressions
342 // and combines the two. It pops off the operand expressions and pushes the
343 // combined result (although this is done in-place on its LHS operand expr).
344 // When the walk is completed, the flattened form of the top-level expression
345 // would be left on the stack.
346 //
347 // A flattener can be repeatedly used for multiple affine expressions that bind
348 // to the same operands, for example, for all result expressions of an
349 // AffineMap or AffineValueMap. In such cases, using it for multiple expressions
350 // is more efficient than creating a new flattener for each expression since
351 // common identical div and mod expressions appearing across different
352 // expressions are mapped to the same local identifier (same column position in
353 // 'localVarCst').
354 class SimpleAffineExprFlattener
355     : public AffineExprVisitor<SimpleAffineExprFlattener, LogicalResult> {
356 public:
357   // Flattend expression layout: [dims, symbols, locals, constant]
358   // Stack that holds the LHS and RHS operands while visiting a binary op expr.
359   // In future, consider adding a prepass to determine how big the SmallVector's
360   // will be, and linearize this to std::vector<int64_t> to prevent
361   // SmallVector moves on re-allocation.
362   std::vector<SmallVector<int64_t, 8>> operandExprStack;
363 
364   unsigned numDims;
365   unsigned numSymbols;
366 
367   // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv's.
368   unsigned numLocals;
369 
370   // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
371   // which new identifiers were introduced; if the latter do not get canceled
372   // out, these expressions can be readily used to reconstruct the AffineExpr
373   // (tree) form. Note that these expressions themselves would have been
374   // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4
375   // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1)
376   // ceildiv 2 would be the local expression stored for q.
377   SmallVector<AffineExpr, 4> localExprs;
378 
379   SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols);
380 
381   virtual ~SimpleAffineExprFlattener() = default;
382 
383   // Visitor method overrides.
384   LogicalResult visitMulExpr(AffineBinaryOpExpr expr);
385   LogicalResult visitAddExpr(AffineBinaryOpExpr expr);
386   LogicalResult visitDimExpr(AffineDimExpr expr);
387   LogicalResult visitSymbolExpr(AffineSymbolExpr expr);
388   LogicalResult visitConstantExpr(AffineConstantExpr expr);
389   LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr);
390   LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr);
391 
392   //
393   // t = expr mod c   <=>  t = expr - c*q and c*q <= expr <= c*q + c - 1
394   //
395   // A mod expression "expr mod c" is thus flattened by introducing a new local
396   // variable q (= expr floordiv c), such that expr mod c is replaced with
397   // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
398   LogicalResult visitModExpr(AffineBinaryOpExpr expr);
399 
400 protected:
401   // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
402   // The local identifier added is always a floordiv of a pure add/mul affine
403   // function of other identifiers, coefficients of which are specified in
404   // dividend and with respect to a positive constant divisor. localExpr is the
405   // simplified tree expression (AffineExpr) corresponding to the quantifier.
406   virtual void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
407                                   AffineExpr localExpr);
408 
409   /// Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul
410   /// expr) when the rhs is a symbolic expression. The local identifier added
411   /// may be a floordiv, ceildiv, mul or mod of a pure affine/semi-affine
412   /// function of other identifiers, coefficients of which are specified in the
413   /// lhs of the mod, floordiv, ceildiv or mul expression and with respect to a
414   /// symbolic rhs expression. `localExpr` is the simplified tree expression
415   /// (AffineExpr) corresponding to the quantifier.
416   virtual LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
417                                              ArrayRef<int64_t> rhs,
418                                              AffineExpr localExpr);
419 
420 private:
421   /// Adds `localExpr`, which may be mod, ceildiv, floordiv or mod expression
422   /// representing the affine expression corresponding to the quantifier
423   /// introduced as the local variable corresponding to `localExpr`. If the
424   /// quantifier is already present, we put the coefficient in the proper index
425   /// of `result`, otherwise we add a new local variable and put the coefficient
426   /// there.
427   LogicalResult addLocalVariableSemiAffine(ArrayRef<int64_t> lhs,
428                                            ArrayRef<int64_t> rhs,
429                                            AffineExpr localExpr,
430                                            SmallVectorImpl<int64_t> &result,
431                                            unsigned long resultSize);
432 
433   // t = expr floordiv c   <=> t = q, c * q <= expr <= c * q + c - 1
434   // A floordiv is thus flattened by introducing a new local variable q, and
435   // replacing that expression with 'q' while adding the constraints
436   // c * q <= expr <= c * q + c - 1 to localVarCst (done by
437   // IntegerRelation::addLocalFloorDiv).
438   //
439   // A ceildiv is similarly flattened:
440   // t = expr ceildiv c   <=> t =  (expr + c - 1) floordiv c
441   LogicalResult visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
442 
443   int findLocalId(AffineExpr localExpr);
444 
getNumCols()445   inline unsigned getNumCols() const {
446     return numDims + numSymbols + numLocals + 1;
447   }
getConstantIndex()448   inline unsigned getConstantIndex() const { return getNumCols() - 1; }
getLocalVarStartIndex()449   inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; }
getSymbolStartIndex()450   inline unsigned getSymbolStartIndex() const { return numDims; }
getDimStartIndex()451   inline unsigned getDimStartIndex() const { return 0; }
452 };
453 
454 } // namespace mlir
455 
456 #endif // MLIR_IR_AFFINEEXPRVISITOR_H
457