1 //===- AffineExprVisitor.h - MLIR AffineExpr Visitor Class ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the AffineExpr visitor class. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_IR_AFFINEEXPRVISITOR_H 14 #define MLIR_IR_AFFINEEXPRVISITOR_H 15 16 #include "mlir/IR/AffineExpr.h" 17 #include "mlir/Support/LLVM.h" 18 #include "llvm/ADT/ArrayRef.h" 19 20 namespace mlir { 21 22 /// Base class for AffineExpr visitors/walkers. 23 /// 24 /// AffineExpr visitors are used when you want to perform different actions 25 /// for different kinds of AffineExprs without having to use lots of casts 26 /// and a big switch instruction. 27 /// 28 /// To define your own visitor, inherit from this class, specifying your 29 /// new type for the 'SubClass' template parameter, and "override" visitXXX 30 /// functions in your class. This class is defined in terms of statically 31 /// resolved overloading, not virtual functions. 32 /// 33 /// The visitor is templated on its return type (`RetTy`). With a WalkResult 34 /// return type, the visitor supports interrupting walks. 35 /// 36 /// For example, here is a visitor that counts the number of for AffineDimExprs 37 /// in an AffineExpr. 38 /// 39 /// /// Declare the class. Note that we derive from AffineExprVisitor 40 /// /// instantiated with our new subclasses_ type. 41 /// 42 /// struct DimExprCounter : public AffineExprVisitor<DimExprCounter> { 43 /// unsigned numDimExprs; 44 /// DimExprCounter() : numDimExprs(0) {} 45 /// void visitDimExpr(AffineDimExpr expr) { ++numDimExprs; } 46 /// }; 47 /// 48 /// And this class would be used like this: 49 /// DimExprCounter dec; 50 /// dec.visit(affineExpr); 51 /// numDimExprs = dec.numDimExprs; 52 /// 53 /// AffineExprVisitor provides visit methods for the following binary affine 54 /// op expressions: 55 /// AffineBinaryAddOpExpr, AffineBinaryMulOpExpr, 56 /// AffineBinaryModOpExpr, AffineBinaryFloorDivOpExpr, 57 /// AffineBinaryCeilDivOpExpr. Note that default implementations of these 58 /// methods will call the general AffineBinaryOpExpr method. 59 /// 60 /// In addition, visit methods are provided for the following affine 61 // expressions: AffineConstantExpr, AffineDimExpr, and 62 // AffineSymbolExpr. 63 /// 64 /// Note that if you don't implement visitXXX for some affine expression type, 65 /// the visitXXX method for Instruction superclass will be invoked. 66 /// 67 /// Note that this class is specifically designed as a template to avoid 68 /// virtual function call overhead. Defining and using a AffineExprVisitor is 69 /// just as efficient as having your own switch instruction over the instruction 70 /// opcode. 71 template <typename SubClass, typename RetTy> 72 class AffineExprVisitorBase { 73 public: 74 // Function to visit an AffineExpr. visit(AffineExpr expr)75 RetTy visit(AffineExpr expr) { 76 static_assert(std::is_base_of<AffineExprVisitorBase, SubClass>::value, 77 "Must instantiate with a derived type of AffineExprVisitor"); 78 auto self = static_cast<SubClass *>(this); 79 switch (expr.getKind()) { 80 case AffineExprKind::Add: { 81 auto binOpExpr = cast<AffineBinaryOpExpr>(expr); 82 return self->visitAddExpr(binOpExpr); 83 } 84 case AffineExprKind::Mul: { 85 auto binOpExpr = cast<AffineBinaryOpExpr>(expr); 86 return self->visitMulExpr(binOpExpr); 87 } 88 case AffineExprKind::Mod: { 89 auto binOpExpr = cast<AffineBinaryOpExpr>(expr); 90 return self->visitModExpr(binOpExpr); 91 } 92 case AffineExprKind::FloorDiv: { 93 auto binOpExpr = cast<AffineBinaryOpExpr>(expr); 94 return self->visitFloorDivExpr(binOpExpr); 95 } 96 case AffineExprKind::CeilDiv: { 97 auto binOpExpr = cast<AffineBinaryOpExpr>(expr); 98 return self->visitCeilDivExpr(binOpExpr); 99 } 100 case AffineExprKind::Constant: 101 return self->visitConstantExpr(cast<AffineConstantExpr>(expr)); 102 case AffineExprKind::DimId: 103 return self->visitDimExpr(cast<AffineDimExpr>(expr)); 104 case AffineExprKind::SymbolId: 105 return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr)); 106 } 107 llvm_unreachable("Unknown AffineExpr"); 108 } 109 110 //===--------------------------------------------------------------------===// 111 // Visitation functions... these functions provide default fallbacks in case 112 // the user does not specify what to do for a particular instruction type. 113 // The default behavior is to generalize the instruction type to its subtype 114 // and try visiting the subtype. All of this should be inlined perfectly, 115 // because there are no virtual functions to get in the way. 116 // 117 118 // Default visit methods. Note that the default op-specific binary op visit 119 // methods call the general visitAffineBinaryOpExpr visit method. visitAffineBinaryOpExpr(AffineBinaryOpExpr expr)120 RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); } visitAddExpr(AffineBinaryOpExpr expr)121 RetTy visitAddExpr(AffineBinaryOpExpr expr) { 122 return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); 123 } visitMulExpr(AffineBinaryOpExpr expr)124 RetTy visitMulExpr(AffineBinaryOpExpr expr) { 125 return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); 126 } visitModExpr(AffineBinaryOpExpr expr)127 RetTy visitModExpr(AffineBinaryOpExpr expr) { 128 return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); 129 } visitFloorDivExpr(AffineBinaryOpExpr expr)130 RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) { 131 return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); 132 } visitCeilDivExpr(AffineBinaryOpExpr expr)133 RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) { 134 return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); 135 } visitConstantExpr(AffineConstantExpr expr)136 RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); } visitDimExpr(AffineDimExpr expr)137 RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); } visitSymbolExpr(AffineSymbolExpr expr)138 RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); } 139 }; 140 141 /// See documentation for AffineExprVisitorBase. This visitor supports 142 /// interrupting walks when a `WalkResult` is used for `RetTy`. 143 template <typename SubClass, typename RetTy = void> 144 class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> { 145 //===--------------------------------------------------------------------===// 146 // Interface code - This is the public interface of the AffineExprVisitor 147 // that you use to visit affine expressions... 148 public: 149 // Function to walk an AffineExpr (in post order). walkPostOrder(AffineExpr expr)150 RetTy walkPostOrder(AffineExpr expr) { 151 static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value, 152 "Must instantiate with a derived type of AffineExprVisitor"); 153 auto self = static_cast<SubClass *>(this); 154 switch (expr.getKind()) { 155 case AffineExprKind::Add: { 156 auto binOpExpr = cast<AffineBinaryOpExpr>(expr); 157 if constexpr (std::is_same<RetTy, WalkResult>::value) { 158 if (walkOperandsPostOrder(binOpExpr).wasInterrupted()) 159 return WalkResult::interrupt(); 160 } else { 161 walkOperandsPostOrder(binOpExpr); 162 } 163 return self->visitAddExpr(binOpExpr); 164 } 165 case AffineExprKind::Mul: { 166 auto binOpExpr = cast<AffineBinaryOpExpr>(expr); 167 if constexpr (std::is_same<RetTy, WalkResult>::value) { 168 if (walkOperandsPostOrder(binOpExpr).wasInterrupted()) 169 return WalkResult::interrupt(); 170 } else { 171 walkOperandsPostOrder(binOpExpr); 172 } 173 return self->visitMulExpr(binOpExpr); 174 } 175 case AffineExprKind::Mod: { 176 auto binOpExpr = cast<AffineBinaryOpExpr>(expr); 177 if constexpr (std::is_same<RetTy, WalkResult>::value) { 178 if (walkOperandsPostOrder(binOpExpr).wasInterrupted()) 179 return WalkResult::interrupt(); 180 } else { 181 walkOperandsPostOrder(binOpExpr); 182 } 183 return self->visitModExpr(binOpExpr); 184 } 185 case AffineExprKind::FloorDiv: { 186 auto binOpExpr = cast<AffineBinaryOpExpr>(expr); 187 if constexpr (std::is_same<RetTy, WalkResult>::value) { 188 if (walkOperandsPostOrder(binOpExpr).wasInterrupted()) 189 return WalkResult::interrupt(); 190 } else { 191 walkOperandsPostOrder(binOpExpr); 192 } 193 return self->visitFloorDivExpr(binOpExpr); 194 } 195 case AffineExprKind::CeilDiv: { 196 auto binOpExpr = cast<AffineBinaryOpExpr>(expr); 197 if constexpr (std::is_same<RetTy, WalkResult>::value) { 198 if (walkOperandsPostOrder(binOpExpr).wasInterrupted()) 199 return WalkResult::interrupt(); 200 } else { 201 walkOperandsPostOrder(binOpExpr); 202 } 203 return self->visitCeilDivExpr(binOpExpr); 204 } 205 case AffineExprKind::Constant: 206 return self->visitConstantExpr(cast<AffineConstantExpr>(expr)); 207 case AffineExprKind::DimId: 208 return self->visitDimExpr(cast<AffineDimExpr>(expr)); 209 case AffineExprKind::SymbolId: 210 return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr)); 211 } 212 llvm_unreachable("Unknown AffineExpr"); 213 } 214 215 private: 216 // Walk the operands - each operand is itself walked in post order. walkOperandsPostOrder(AffineBinaryOpExpr expr)217 RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) { 218 if constexpr (std::is_same<RetTy, WalkResult>::value) { 219 if (walkPostOrder(expr.getLHS()).wasInterrupted()) 220 return WalkResult::interrupt(); 221 } else { 222 walkPostOrder(expr.getLHS()); 223 } 224 if constexpr (std::is_same<RetTy, WalkResult>::value) { 225 if (walkPostOrder(expr.getRHS()).wasInterrupted()) 226 return WalkResult::interrupt(); 227 return WalkResult::advance(); 228 } else { 229 return walkPostOrder(expr.getRHS()); 230 } 231 } 232 }; 233 234 template <typename SubClass> 235 class AffineExprVisitor<SubClass, LogicalResult> 236 : public AffineExprVisitorBase<SubClass, LogicalResult> { 237 //===--------------------------------------------------------------------===// 238 // Interface code - This is the public interface of the AffineExprVisitor 239 // that you use to visit affine expressions... 240 public: 241 // Function to walk an AffineExpr (in post order). walkPostOrder(AffineExpr expr)242 LogicalResult walkPostOrder(AffineExpr expr) { 243 static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value, 244 "Must instantiate with a derived type of AffineExprVisitor"); 245 auto self = static_cast<SubClass *>(this); 246 switch (expr.getKind()) { 247 case AffineExprKind::Add: { 248 auto binOpExpr = cast<AffineBinaryOpExpr>(expr); 249 if (failed(walkOperandsPostOrder(binOpExpr))) 250 return failure(); 251 return self->visitAddExpr(binOpExpr); 252 } 253 case AffineExprKind::Mul: { 254 auto binOpExpr = cast<AffineBinaryOpExpr>(expr); 255 if (failed(walkOperandsPostOrder(binOpExpr))) 256 return failure(); 257 return self->visitMulExpr(binOpExpr); 258 } 259 case AffineExprKind::Mod: { 260 auto binOpExpr = cast<AffineBinaryOpExpr>(expr); 261 if (failed(walkOperandsPostOrder(binOpExpr))) 262 return failure(); 263 return self->visitModExpr(binOpExpr); 264 } 265 case AffineExprKind::FloorDiv: { 266 auto binOpExpr = cast<AffineBinaryOpExpr>(expr); 267 if (failed(walkOperandsPostOrder(binOpExpr))) 268 return failure(); 269 return self->visitFloorDivExpr(binOpExpr); 270 } 271 case AffineExprKind::CeilDiv: { 272 auto binOpExpr = cast<AffineBinaryOpExpr>(expr); 273 if (failed(walkOperandsPostOrder(binOpExpr))) 274 return failure(); 275 return self->visitCeilDivExpr(binOpExpr); 276 } 277 case AffineExprKind::Constant: 278 return self->visitConstantExpr(cast<AffineConstantExpr>(expr)); 279 case AffineExprKind::DimId: 280 return self->visitDimExpr(cast<AffineDimExpr>(expr)); 281 case AffineExprKind::SymbolId: 282 return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr)); 283 } 284 llvm_unreachable("Unknown AffineExpr"); 285 } 286 287 private: 288 // Walk the operands - each operand is itself walked in post order. walkOperandsPostOrder(AffineBinaryOpExpr expr)289 LogicalResult walkOperandsPostOrder(AffineBinaryOpExpr expr) { 290 if (failed(walkPostOrder(expr.getLHS()))) 291 return failure(); 292 if (failed(walkPostOrder(expr.getRHS()))) 293 return failure(); 294 return success(); 295 } 296 }; 297 298 // This class is used to flatten a pure affine expression (AffineExpr, 299 // which is in a tree form) into a sum of products (w.r.t constants) when 300 // possible, and in that process simplifying the expression. For a modulo, 301 // floordiv, or a ceildiv expression, an additional identifier, called a local 302 // identifier, is introduced to rewrite the expression as a sum of product 303 // affine expression. Each local identifier is always and by construction a 304 // floordiv of a pure add/mul affine function of dimensional, symbolic, and 305 // other local identifiers, in a non-mutually recursive way. Hence, every local 306 // identifier can ultimately always be recovered as an affine function of 307 // dimensional and symbolic identifiers (involving floordiv's); note however 308 // that by AffineExpr construction, some floordiv combinations are converted to 309 // mod's. The result of the flattening is a flattened expression and a set of 310 // constraints involving just the local variables. 311 // 312 // d2 + (d0 + d1) floordiv 4 is flattened to d2 + q where 'q' is the local 313 // variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3. 314 // 315 // The simplification performed includes the accumulation of contributions for 316 // each dimensional and symbolic identifier together, the simplification of 317 // floordiv/ceildiv/mod expressions and other simplifications that in turn 318 // happen as a result. A simplification that this flattening naturally performs 319 // is of simplifying the numerator and denominator of floordiv/ceildiv, and 320 // folding a modulo expression to a zero, if possible. Three examples are below: 321 // 322 // (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1 323 // (d0 - d0 mod 4 + 4) mod 4 simplified to 0 324 // (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1 325 // 326 // The way the flattening works for the second example is as follows: d0 % 4 is 327 // replaced by d0 - 4*q with q being introduced: the expression then simplifies 328 // to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to 329 // zero. Note that an affine expression may not always be expressible purely as 330 // a sum of products involving just the original dimensional and symbolic 331 // identifiers due to the presence of modulo/floordiv/ceildiv expressions that 332 // may not be eliminated after simplification; in such cases, the final 333 // expression can be reconstructed by replacing the local identifiers with their 334 // corresponding explicit form stored in 'localExprs' (note that each of the 335 // explicit forms itself would have been simplified). 336 // 337 // The expression walk method here performs a linear time post order walk that 338 // performs the above simplifications through visit methods, with partial 339 // results being stored in 'operandExprStack'. When a parent expr is visited, 340 // the flattened expressions corresponding to its two operands would already be 341 // on the stack - the parent expression looks at the two flattened expressions 342 // and combines the two. It pops off the operand expressions and pushes the 343 // combined result (although this is done in-place on its LHS operand expr). 344 // When the walk is completed, the flattened form of the top-level expression 345 // would be left on the stack. 346 // 347 // A flattener can be repeatedly used for multiple affine expressions that bind 348 // to the same operands, for example, for all result expressions of an 349 // AffineMap or AffineValueMap. In such cases, using it for multiple expressions 350 // is more efficient than creating a new flattener for each expression since 351 // common identical div and mod expressions appearing across different 352 // expressions are mapped to the same local identifier (same column position in 353 // 'localVarCst'). 354 class SimpleAffineExprFlattener 355 : public AffineExprVisitor<SimpleAffineExprFlattener, LogicalResult> { 356 public: 357 // Flattend expression layout: [dims, symbols, locals, constant] 358 // Stack that holds the LHS and RHS operands while visiting a binary op expr. 359 // In future, consider adding a prepass to determine how big the SmallVector's 360 // will be, and linearize this to std::vector<int64_t> to prevent 361 // SmallVector moves on re-allocation. 362 std::vector<SmallVector<int64_t, 8>> operandExprStack; 363 364 unsigned numDims; 365 unsigned numSymbols; 366 367 // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv's. 368 unsigned numLocals; 369 370 // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for 371 // which new identifiers were introduced; if the latter do not get canceled 372 // out, these expressions can be readily used to reconstruct the AffineExpr 373 // (tree) form. Note that these expressions themselves would have been 374 // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4 375 // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1) 376 // ceildiv 2 would be the local expression stored for q. 377 SmallVector<AffineExpr, 4> localExprs; 378 379 SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols); 380 381 virtual ~SimpleAffineExprFlattener() = default; 382 383 // Visitor method overrides. 384 LogicalResult visitMulExpr(AffineBinaryOpExpr expr); 385 LogicalResult visitAddExpr(AffineBinaryOpExpr expr); 386 LogicalResult visitDimExpr(AffineDimExpr expr); 387 LogicalResult visitSymbolExpr(AffineSymbolExpr expr); 388 LogicalResult visitConstantExpr(AffineConstantExpr expr); 389 LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr); 390 LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr); 391 392 // 393 // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1 394 // 395 // A mod expression "expr mod c" is thus flattened by introducing a new local 396 // variable q (= expr floordiv c), such that expr mod c is replaced with 397 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst. 398 LogicalResult visitModExpr(AffineBinaryOpExpr expr); 399 400 protected: 401 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). 402 // The local identifier added is always a floordiv of a pure add/mul affine 403 // function of other identifiers, coefficients of which are specified in 404 // dividend and with respect to a positive constant divisor. localExpr is the 405 // simplified tree expression (AffineExpr) corresponding to the quantifier. 406 virtual void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor, 407 AffineExpr localExpr); 408 409 /// Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul 410 /// expr) when the rhs is a symbolic expression. The local identifier added 411 /// may be a floordiv, ceildiv, mul or mod of a pure affine/semi-affine 412 /// function of other identifiers, coefficients of which are specified in the 413 /// lhs of the mod, floordiv, ceildiv or mul expression and with respect to a 414 /// symbolic rhs expression. `localExpr` is the simplified tree expression 415 /// (AffineExpr) corresponding to the quantifier. 416 virtual LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs, 417 ArrayRef<int64_t> rhs, 418 AffineExpr localExpr); 419 420 private: 421 /// Adds `localExpr`, which may be mod, ceildiv, floordiv or mod expression 422 /// representing the affine expression corresponding to the quantifier 423 /// introduced as the local variable corresponding to `localExpr`. If the 424 /// quantifier is already present, we put the coefficient in the proper index 425 /// of `result`, otherwise we add a new local variable and put the coefficient 426 /// there. 427 LogicalResult addLocalVariableSemiAffine(ArrayRef<int64_t> lhs, 428 ArrayRef<int64_t> rhs, 429 AffineExpr localExpr, 430 SmallVectorImpl<int64_t> &result, 431 unsigned long resultSize); 432 433 // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1 434 // A floordiv is thus flattened by introducing a new local variable q, and 435 // replacing that expression with 'q' while adding the constraints 436 // c * q <= expr <= c * q + c - 1 to localVarCst (done by 437 // IntegerRelation::addLocalFloorDiv). 438 // 439 // A ceildiv is similarly flattened: 440 // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c 441 LogicalResult visitDivExpr(AffineBinaryOpExpr expr, bool isCeil); 442 443 int findLocalId(AffineExpr localExpr); 444 getNumCols()445 inline unsigned getNumCols() const { 446 return numDims + numSymbols + numLocals + 1; 447 } getConstantIndex()448 inline unsigned getConstantIndex() const { return getNumCols() - 1; } getLocalVarStartIndex()449 inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; } getSymbolStartIndex()450 inline unsigned getSymbolStartIndex() const { return numDims; } getDimStartIndex()451 inline unsigned getDimStartIndex() const { return 0; } 452 }; 453 454 } // namespace mlir 455 456 #endif // MLIR_IR_AFFINEEXPRVISITOR_H 457