1 //===- Utils.cpp ---- Misc utilities for analysis -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements miscellaneous analysis routines for non-loop IR 10 // structures. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Affine/Analysis/Utils.h" 15 #include "mlir/Analysis/Presburger/PresburgerRelation.h" 16 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h" 17 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" 18 #include "mlir/Dialect/Affine/IR/AffineOps.h" 19 #include "mlir/Dialect/Affine/IR/AffineValueMap.h" 20 #include "mlir/Dialect/Arith/IR/Arith.h" 21 #include "mlir/Dialect/Utils/StaticValueUtils.h" 22 #include "mlir/IR/IntegerSet.h" 23 #include "mlir/Interfaces/CallInterfaces.h" 24 #include "llvm/ADT/SetVector.h" 25 #include "llvm/ADT/SmallPtrSet.h" 26 #include "llvm/Support/Debug.h" 27 #include "llvm/Support/raw_ostream.h" 28 #include <optional> 29 30 #define DEBUG_TYPE "analysis-utils" 31 32 using namespace mlir; 33 using namespace affine; 34 using namespace presburger; 35 36 using llvm::SmallDenseMap; 37 38 using Node = MemRefDependenceGraph::Node; 39 40 // LoopNestStateCollector walks loop nests and collects load and store 41 // operations, and whether or not a region holding op other than ForOp and IfOp 42 // was encountered in the loop nest. 43 void LoopNestStateCollector::collect(Operation *opToWalk) { 44 opToWalk->walk([&](Operation *op) { 45 if (isa<AffineForOp>(op)) 46 forOps.push_back(cast<AffineForOp>(op)); 47 else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op)) 48 hasNonAffineRegionOp = true; 49 else if (isa<AffineReadOpInterface>(op)) 50 loadOpInsts.push_back(op); 51 else if (isa<AffineWriteOpInterface>(op)) 52 storeOpInsts.push_back(op); 53 }); 54 } 55 56 // Returns the load op count for 'memref'. 57 unsigned Node::getLoadOpCount(Value memref) const { 58 unsigned loadOpCount = 0; 59 for (Operation *loadOp : loads) { 60 if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef()) 61 ++loadOpCount; 62 } 63 return loadOpCount; 64 } 65 66 // Returns the store op count for 'memref'. 67 unsigned Node::getStoreOpCount(Value memref) const { 68 unsigned storeOpCount = 0; 69 for (Operation *storeOp : stores) { 70 if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef()) 71 ++storeOpCount; 72 } 73 return storeOpCount; 74 } 75 76 // Returns all store ops in 'storeOps' which access 'memref'. 77 void Node::getStoreOpsForMemref(Value memref, 78 SmallVectorImpl<Operation *> *storeOps) const { 79 for (Operation *storeOp : stores) { 80 if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef()) 81 storeOps->push_back(storeOp); 82 } 83 } 84 85 // Returns all load ops in 'loadOps' which access 'memref'. 86 void Node::getLoadOpsForMemref(Value memref, 87 SmallVectorImpl<Operation *> *loadOps) const { 88 for (Operation *loadOp : loads) { 89 if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef()) 90 loadOps->push_back(loadOp); 91 } 92 } 93 94 // Returns all memrefs in 'loadAndStoreMemrefSet' for which this node 95 // has at least one load and store operation. 
96 void Node::getLoadAndStoreMemrefSet( 97 DenseSet<Value> *loadAndStoreMemrefSet) const { 98 llvm::SmallDenseSet<Value, 2> loadMemrefs; 99 for (Operation *loadOp : loads) { 100 loadMemrefs.insert(cast<AffineReadOpInterface>(loadOp).getMemRef()); 101 } 102 for (Operation *storeOp : stores) { 103 auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef(); 104 if (loadMemrefs.count(memref) > 0) 105 loadAndStoreMemrefSet->insert(memref); 106 } 107 } 108 109 // Initializes the data dependence graph by walking operations in `block`. 110 // Assigns each node in the graph a node id based on program order in 'f'. 111 bool MemRefDependenceGraph::init() { 112 LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n"); 113 // Map from a memref to the set of ids of the nodes that have ops accessing 114 // the memref. 115 DenseMap<Value, SetVector<unsigned>> memrefAccesses; 116 117 DenseMap<Operation *, unsigned> forToNodeMap; 118 for (Operation &op : block) { 119 if (dyn_cast<AffineForOp>(op)) { 120 // Create graph node 'id' to represent top-level 'forOp' and record 121 // all loads and store accesses it contains. 122 LoopNestStateCollector collector; 123 collector.collect(&op); 124 // Return false if a region holding op other than 'affine.for' and 125 // 'affine.if' was found (not currently supported). 126 if (collector.hasNonAffineRegionOp) 127 return false; 128 Node node(nextNodeId++, &op); 129 for (auto *opInst : collector.loadOpInsts) { 130 node.loads.push_back(opInst); 131 auto memref = cast<AffineReadOpInterface>(opInst).getMemRef(); 132 memrefAccesses[memref].insert(node.id); 133 } 134 for (auto *opInst : collector.storeOpInsts) { 135 node.stores.push_back(opInst); 136 auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef(); 137 memrefAccesses[memref].insert(node.id); 138 } 139 forToNodeMap[&op] = node.id; 140 nodes.insert({node.id, node}); 141 } else if (dyn_cast<AffineReadOpInterface>(op)) { 142 // Create graph node for top-level load op. 143 Node node(nextNodeId++, &op); 144 node.loads.push_back(&op); 145 auto memref = cast<AffineReadOpInterface>(op).getMemRef(); 146 memrefAccesses[memref].insert(node.id); 147 nodes.insert({node.id, node}); 148 } else if (dyn_cast<AffineWriteOpInterface>(op)) { 149 // Create graph node for top-level store op. 150 Node node(nextNodeId++, &op); 151 node.stores.push_back(&op); 152 auto memref = cast<AffineWriteOpInterface>(op).getMemRef(); 153 memrefAccesses[memref].insert(node.id); 154 nodes.insert({node.id, node}); 155 } else if (op.getNumResults() > 0 && !op.use_empty()) { 156 // Create graph node for top-level producer of SSA values, which 157 // could be used by loop nest nodes. 158 Node node(nextNodeId++, &op); 159 nodes.insert({node.id, node}); 160 } else if (!isMemoryEffectFree(&op) && 161 (op.getNumRegions() == 0 || isa<RegionBranchOpInterface>(op))) { 162 // Create graph node for top-level op unless it is known to be 163 // memory-effect free. This covers all unknown/unregistered ops, 164 // non-affine ops with memory effects, and region-holding ops with a 165 // well-defined control flow. During the fusion validity checks, we look 166 // for non-affine ops on the path from source to destination, at which 167 // point we check which memrefs if any are used in the region. 168 Node node(nextNodeId++, &op); 169 nodes.insert({node.id, node}); 170 } else if (op.getNumRegions() != 0) { 171 // Return false if non-handled/unknown region-holding ops are found. 
We 172 // won't know what such ops do or what its regions mean; for e.g., it may 173 // not be an imperative op. 174 LLVM_DEBUG(llvm::dbgs() 175 << "MDG init failed; unknown region-holding op found!\n"); 176 return false; 177 } 178 } 179 180 for (auto &idAndNode : nodes) { 181 LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n" 182 << *(idAndNode.second.op) << "\n"); 183 (void)idAndNode; 184 } 185 186 // Add dependence edges between nodes which produce SSA values and their 187 // users. Load ops can be considered as the ones producing SSA values. 188 for (auto &idAndNode : nodes) { 189 const Node &node = idAndNode.second; 190 // Stores don't define SSA values, skip them. 191 if (!node.stores.empty()) 192 continue; 193 Operation *opInst = node.op; 194 for (Value value : opInst->getResults()) { 195 for (Operation *user : value.getUsers()) { 196 // Ignore users outside of the block. 197 if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() != 198 &block) 199 continue; 200 SmallVector<AffineForOp, 4> loops; 201 getAffineForIVs(*user, &loops); 202 // Find the surrounding affine.for nested immediately within the 203 // block. 204 auto *it = llvm::find_if(loops, [&](AffineForOp loop) { 205 return loop->getBlock() == █ 206 }); 207 if (it == loops.end()) 208 continue; 209 assert(forToNodeMap.count(*it) > 0 && "missing mapping"); 210 unsigned userLoopNestId = forToNodeMap[*it]; 211 addEdge(node.id, userLoopNestId, value); 212 } 213 } 214 } 215 216 // Walk memref access lists and add graph edges between dependent nodes. 217 for (auto &memrefAndList : memrefAccesses) { 218 unsigned n = memrefAndList.second.size(); 219 for (unsigned i = 0; i < n; ++i) { 220 unsigned srcId = memrefAndList.second[i]; 221 bool srcHasStore = 222 getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0; 223 for (unsigned j = i + 1; j < n; ++j) { 224 unsigned dstId = memrefAndList.second[j]; 225 bool dstHasStore = 226 getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0; 227 if (srcHasStore || dstHasStore) 228 addEdge(srcId, dstId, memrefAndList.first); 229 } 230 } 231 } 232 return true; 233 } 234 235 // Returns the graph node for 'id'. 236 Node *MemRefDependenceGraph::getNode(unsigned id) { 237 auto it = nodes.find(id); 238 assert(it != nodes.end()); 239 return &it->second; 240 } 241 242 // Returns the graph node for 'forOp'. 243 Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) { 244 for (auto &idAndNode : nodes) 245 if (idAndNode.second.op == forOp) 246 return &idAndNode.second; 247 return nullptr; 248 } 249 250 // Adds a node with 'op' to the graph and returns its unique identifier. 251 unsigned MemRefDependenceGraph::addNode(Operation *op) { 252 Node node(nextNodeId++, op); 253 nodes.insert({node.id, node}); 254 return node.id; 255 } 256 257 // Remove node 'id' (and its associated edges) from graph. 258 void MemRefDependenceGraph::removeNode(unsigned id) { 259 // Remove each edge in 'inEdges[id]'. 260 if (inEdges.count(id) > 0) { 261 SmallVector<Edge, 2> oldInEdges = inEdges[id]; 262 for (auto &inEdge : oldInEdges) { 263 removeEdge(inEdge.id, id, inEdge.value); 264 } 265 } 266 // Remove each edge in 'outEdges[id]'. 267 if (outEdges.count(id) > 0) { 268 SmallVector<Edge, 2> oldOutEdges = outEdges[id]; 269 for (auto &outEdge : oldOutEdges) { 270 removeEdge(id, outEdge.id, outEdge.value); 271 } 272 } 273 // Erase remaining node state. 
  inEdges.erase(id);
  outEdges.erase(id);
  nodes.erase(id);
}

// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
  Node *node = getNode(id);
  for (auto *storeOpInst : node->stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
    auto *op = memref.getDefiningOp();
    // Return true if 'memref' is a block argument.
    if (!op)
      return true;
    // Return true if any use of 'memref' does not dereference it in an affine
    // way.
    for (auto *user : memref.getUsers())
      if (!isa<AffineMapAccessInterface>(*user))
        return true;
  }
  return false;
}

// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
                                    Value value) {
  if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
    return false;
  }
  bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
    return edge.id == dstId && (!value || edge.value == value);
  });
  bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
    return edge.id == srcId && (!value || edge.value == value);
  });
  return hasOutEdge && hasInEdge;
}

// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
void MemRefDependenceGraph::addEdge(unsigned srcId, unsigned dstId,
                                    Value value) {
  if (!hasEdge(srcId, dstId, value)) {
    outEdges[srcId].push_back({dstId, value});
    inEdges[dstId].push_back({srcId, value});
    if (isa<MemRefType>(value.getType()))
      memrefEdgeCount[value]++;
  }
}

// Removes an edge from node 'srcId' to node 'dstId' for 'value'.
void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
                                       Value value) {
  assert(inEdges.count(dstId) > 0);
  assert(outEdges.count(srcId) > 0);
  if (isa<MemRefType>(value.getType())) {
    assert(memrefEdgeCount.count(value) > 0);
    memrefEdgeCount[value]--;
  }
  // Remove 'srcId' from 'inEdges[dstId]'.
  for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
    if ((*it).id == srcId && (*it).value == value) {
      inEdges[dstId].erase(it);
      break;
    }
  }
  // Remove 'dstId' from 'outEdges[srcId]'.
  for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
    if ((*it).id == dstId && (*it).value == value) {
      outEdges[srcId].erase(it);
      break;
    }
  }
}

// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connect are expected to be from the same block.
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
  // Worklist state is: <node-id, next-output-edge-index-to-visit>
  SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
  worklist.push_back({srcId, 0});
  Operation *dstOp = getNode(dstId)->op;
  // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
  while (!worklist.empty()) {
    auto &idAndIndex = worklist.back();
    // Return true if we have reached 'dstId'.
363 if (idAndIndex.first == dstId) 364 return true; 365 // Pop and continue if node has no out edges, or if all out edges have 366 // already been visited. 367 if (outEdges.count(idAndIndex.first) == 0 || 368 idAndIndex.second == outEdges[idAndIndex.first].size()) { 369 worklist.pop_back(); 370 continue; 371 } 372 // Get graph edge to traverse. 373 Edge edge = outEdges[idAndIndex.first][idAndIndex.second]; 374 // Increment next output edge index for 'idAndIndex'. 375 ++idAndIndex.second; 376 // Add node at 'edge.id' to the worklist. We don't need to consider 377 // nodes that are "after" dstId in the containing block; one can't have a 378 // path to `dstId` from any of those nodes. 379 bool afterDst = dstOp->isBeforeInBlock(getNode(edge.id)->op); 380 if (!afterDst && edge.id != idAndIndex.first) 381 worklist.push_back({edge.id, 0}); 382 } 383 return false; 384 } 385 386 // Returns the input edge count for node 'id' and 'memref' from src nodes 387 // which access 'memref' with a store operation. 388 unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id, 389 Value memref) { 390 unsigned inEdgeCount = 0; 391 if (inEdges.count(id) > 0) 392 for (auto &inEdge : inEdges[id]) 393 if (inEdge.value == memref) { 394 Node *srcNode = getNode(inEdge.id); 395 // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref' 396 if (srcNode->getStoreOpCount(memref) > 0) 397 ++inEdgeCount; 398 } 399 return inEdgeCount; 400 } 401 402 // Returns the output edge count for node 'id' and 'memref' (if non-null), 403 // otherwise returns the total output edge count from node 'id'. 404 unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) { 405 unsigned outEdgeCount = 0; 406 if (outEdges.count(id) > 0) 407 for (auto &outEdge : outEdges[id]) 408 if (!memref || outEdge.value == memref) 409 ++outEdgeCount; 410 return outEdgeCount; 411 } 412 413 /// Return all nodes which define SSA values used in node 'id'. 414 void MemRefDependenceGraph::gatherDefiningNodes( 415 unsigned id, DenseSet<unsigned> &definingNodes) { 416 for (MemRefDependenceGraph::Edge edge : inEdges[id]) 417 // By definition of edge, if the edge value is a non-memref value, 418 // then the dependence is between a graph node which defines an SSA value 419 // and another graph node which uses the SSA value. 420 if (!isa<MemRefType>(edge.value.getType())) 421 definingNodes.insert(edge.id); 422 } 423 424 // Computes and returns an insertion point operation, before which the 425 // the fused <srcId, dstId> loop nest can be inserted while preserving 426 // dependences. Returns nullptr if no such insertion point is found. 427 Operation * 428 MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId, 429 unsigned dstId) { 430 if (outEdges.count(srcId) == 0) 431 return getNode(dstId)->op; 432 433 // Skip if there is any defining node of 'dstId' that depends on 'srcId'. 434 DenseSet<unsigned> definingNodes; 435 gatherDefiningNodes(dstId, definingNodes); 436 if (llvm::any_of(definingNodes, 437 [&](unsigned id) { return hasDependencePath(srcId, id); })) { 438 LLVM_DEBUG(llvm::dbgs() 439 << "Can't fuse: a defining op with a user in the dst " 440 "loop has dependence from the src loop\n"); 441 return nullptr; 442 } 443 444 // Build set of insts in range (srcId, dstId) which depend on 'srcId'. 
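  // These are the ops of all nodes other than 'dstId' that 'srcId' has an out
  // edge to, i.e., that depend on 'srcId' through an SSA value or a memref.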
  SmallPtrSet<Operation *, 2> srcDepInsts;
  for (auto &outEdge : outEdges[srcId])
    if (outEdge.id != dstId)
      srcDepInsts.insert(getNode(outEdge.id)->op);

  // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
  SmallPtrSet<Operation *, 2> dstDepInsts;
  for (auto &inEdge : inEdges[dstId])
    if (inEdge.id != srcId)
      dstDepInsts.insert(getNode(inEdge.id)->op);

  Operation *srcNodeInst = getNode(srcId)->op;
  Operation *dstNodeInst = getNode(dstId)->op;

  // Computing insertion point:
  // *) Walk all operation positions in Block operation list in the
  //    range (src, dst). For each operation 'op' visited in this search:
  //   *) Store in 'firstSrcDepPos' the first position where 'op' has a
  //      dependence edge from 'srcNode'.
  //   *) Store in 'lastDstDepPos' the last position where 'op' has a
  //      dependence edge to 'dstNode'.
  // *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
  //    operation insertion point (or return null pointer if no such
  //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
  SmallVector<Operation *, 2> depInsts;
  std::optional<unsigned> firstSrcDepPos;
  std::optional<unsigned> lastDstDepPos;
  unsigned pos = 0;
  for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
       it != Block::iterator(dstNodeInst); ++it) {
    Operation *op = &(*it);
    if (srcDepInsts.count(op) > 0 && firstSrcDepPos == std::nullopt)
      firstSrcDepPos = pos;
    if (dstDepInsts.count(op) > 0)
      lastDstDepPos = pos;
    depInsts.push_back(op);
    ++pos;
  }

  if (firstSrcDepPos.has_value()) {
    if (lastDstDepPos.has_value()) {
      if (*firstSrcDepPos <= *lastDstDepPos) {
        // No valid insertion point exists which preserves dependences.
        return nullptr;
      }
    }
    // Return the insertion point at 'firstSrcDepPos'.
    return depInsts[*firstSrcDepPos];
  }
  // No dependence targets in range (or only dst deps in range), return
  // 'dstNodeInst' insertion point.
  return dstNodeInst;
}

// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
// taking into account that:
// *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
// *) memrefs in 'privateMemRefs' have been replaced in the node at 'dstId' by
//    private memrefs.
void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
                                        const DenseSet<Value> &privateMemRefs,
                                        bool removeSrcId) {
  // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
  if (inEdges.count(srcId) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
    for (auto &inEdge : oldInEdges) {
      // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
      if (privateMemRefs.count(inEdge.value) == 0)
        addEdge(inEdge.id, dstId, inEdge.value);
    }
  }
  // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
  // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
  if (outEdges.count(srcId) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
    for (auto &outEdge : oldOutEdges) {
      // Remove any out edges from 'srcId' to 'dstId' across memrefs.
522 if (outEdge.id == dstId) 523 removeEdge(srcId, outEdge.id, outEdge.value); 524 else if (removeSrcId) { 525 addEdge(dstId, outEdge.id, outEdge.value); 526 removeEdge(srcId, outEdge.id, outEdge.value); 527 } 528 } 529 } 530 // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being 531 // replaced by a private memref). These edges could come from nodes 532 // other than 'srcId' which were removed in the previous step. 533 if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) { 534 SmallVector<Edge, 2> oldInEdges = inEdges[dstId]; 535 for (auto &inEdge : oldInEdges) 536 if (privateMemRefs.count(inEdge.value) > 0) 537 removeEdge(inEdge.id, dstId, inEdge.value); 538 } 539 } 540 541 // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion 542 // of sibling node 'sibId' into node 'dstId'. 543 void MemRefDependenceGraph::updateEdges(unsigned sibId, unsigned dstId) { 544 // For each edge in 'inEdges[sibId]': 545 // *) Add new edge from source node 'inEdge.id' to 'dstNode'. 546 // *) Remove edge from source node 'inEdge.id' to 'sibNode'. 547 if (inEdges.count(sibId) > 0) { 548 SmallVector<Edge, 2> oldInEdges = inEdges[sibId]; 549 for (auto &inEdge : oldInEdges) { 550 addEdge(inEdge.id, dstId, inEdge.value); 551 removeEdge(inEdge.id, sibId, inEdge.value); 552 } 553 } 554 555 // For each edge in 'outEdges[sibId]' to node 'id' 556 // *) Add new edge from 'dstId' to 'outEdge.id'. 557 // *) Remove edge from 'sibId' to 'outEdge.id'. 558 if (outEdges.count(sibId) > 0) { 559 SmallVector<Edge, 2> oldOutEdges = outEdges[sibId]; 560 for (auto &outEdge : oldOutEdges) { 561 addEdge(dstId, outEdge.id, outEdge.value); 562 removeEdge(sibId, outEdge.id, outEdge.value); 563 } 564 } 565 } 566 567 // Adds ops in 'loads' and 'stores' to node at 'id'. 568 void MemRefDependenceGraph::addToNode( 569 unsigned id, const SmallVectorImpl<Operation *> &loads, 570 const SmallVectorImpl<Operation *> &stores) { 571 Node *node = getNode(id); 572 llvm::append_range(node->loads, loads); 573 llvm::append_range(node->stores, stores); 574 } 575 576 void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) { 577 Node *node = getNode(id); 578 node->loads.clear(); 579 node->stores.clear(); 580 } 581 582 // Calls 'callback' for each input edge incident to node 'id' which carries a 583 // memref dependence. 584 void MemRefDependenceGraph::forEachMemRefInputEdge( 585 unsigned id, const std::function<void(Edge)> &callback) { 586 if (inEdges.count(id) > 0) 587 forEachMemRefEdge(inEdges[id], callback); 588 } 589 590 // Calls 'callback' for each output edge from node 'id' which carries a 591 // memref dependence. 592 void MemRefDependenceGraph::forEachMemRefOutputEdge( 593 unsigned id, const std::function<void(Edge)> &callback) { 594 if (outEdges.count(id) > 0) 595 forEachMemRefEdge(outEdges[id], callback); 596 } 597 598 // Calls 'callback' for each edge in 'edges' which carries a memref 599 // dependence. 600 void MemRefDependenceGraph::forEachMemRefEdge( 601 ArrayRef<Edge> edges, const std::function<void(Edge)> &callback) { 602 for (const auto &edge : edges) { 603 // Skip if 'edge' is not a memref dependence edge. 604 if (!isa<MemRefType>(edge.value.getType())) 605 continue; 606 assert(nodes.count(edge.id) > 0); 607 // Skip if 'edge.id' is not a loop nest. 608 if (!isa<AffineForOp>(getNode(edge.id)->op)) 609 continue; 610 // Visit current input edge 'edge'. 
611 callback(edge); 612 } 613 } 614 615 void MemRefDependenceGraph::print(raw_ostream &os) const { 616 os << "\nMemRefDependenceGraph\n"; 617 os << "\nNodes:\n"; 618 for (const auto &idAndNode : nodes) { 619 os << "Node: " << idAndNode.first << "\n"; 620 auto it = inEdges.find(idAndNode.first); 621 if (it != inEdges.end()) { 622 for (const auto &e : it->second) 623 os << " InEdge: " << e.id << " " << e.value << "\n"; 624 } 625 it = outEdges.find(idAndNode.first); 626 if (it != outEdges.end()) { 627 for (const auto &e : it->second) 628 os << " OutEdge: " << e.id << " " << e.value << "\n"; 629 } 630 } 631 } 632 633 void mlir::affine::getAffineForIVs(Operation &op, 634 SmallVectorImpl<AffineForOp> *loops) { 635 auto *currOp = op.getParentOp(); 636 AffineForOp currAffineForOp; 637 // Traverse up the hierarchy collecting all 'affine.for' operation while 638 // skipping over 'affine.if' operations. 639 while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) { 640 if (auto currAffineForOp = dyn_cast<AffineForOp>(currOp)) 641 loops->push_back(currAffineForOp); 642 currOp = currOp->getParentOp(); 643 } 644 std::reverse(loops->begin(), loops->end()); 645 } 646 647 void mlir::affine::getEnclosingAffineOps(Operation &op, 648 SmallVectorImpl<Operation *> *ops) { 649 ops->clear(); 650 Operation *currOp = op.getParentOp(); 651 652 // Traverse up the hierarchy collecting all `affine.for`, `affine.if`, and 653 // affine.parallel operations. 654 while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) { 655 if (isa<AffineIfOp, AffineForOp, AffineParallelOp>(currOp)) 656 ops->push_back(currOp); 657 currOp = currOp->getParentOp(); 658 } 659 std::reverse(ops->begin(), ops->end()); 660 } 661 662 // Populates 'cst' with FlatAffineValueConstraints which represent original 663 // domain of the loop bounds that define 'ivs'. 664 LogicalResult ComputationSliceState::getSourceAsConstraints( 665 FlatAffineValueConstraints &cst) const { 666 assert(!ivs.empty() && "Cannot have a slice without its IVs"); 667 cst = FlatAffineValueConstraints(/*numDims=*/ivs.size(), /*numSymbols=*/0, 668 /*numLocals=*/0, ivs); 669 for (Value iv : ivs) { 670 AffineForOp loop = getForInductionVarOwner(iv); 671 assert(loop && "Expected affine for"); 672 if (failed(cst.addAffineForOpDomain(loop))) 673 return failure(); 674 } 675 return success(); 676 } 677 678 // Populates 'cst' with FlatAffineValueConstraints which represent slice bounds. 679 LogicalResult 680 ComputationSliceState::getAsConstraints(FlatAffineValueConstraints *cst) const { 681 assert(!lbOperands.empty()); 682 // Adds src 'ivs' as dimension variables in 'cst'. 683 unsigned numDims = ivs.size(); 684 // Adds operands (dst ivs and symbols) as symbols in 'cst'. 685 unsigned numSymbols = lbOperands[0].size(); 686 687 SmallVector<Value, 4> values(ivs); 688 // Append 'ivs' then 'operands' to 'values'. 689 values.append(lbOperands[0].begin(), lbOperands[0].end()); 690 *cst = FlatAffineValueConstraints(numDims, numSymbols, 0, values); 691 692 // Add loop bound constraints for values which are loop IVs of the destination 693 // of fusion and equality constraints for symbols which are constants. 694 for (unsigned i = numDims, end = values.size(); i < end; ++i) { 695 Value value = values[i]; 696 assert(cst->containsVar(value) && "value expected to be present"); 697 if (isValidSymbol(value)) { 698 // Check if the symbol is a constant. 
699 if (std::optional<int64_t> cOp = getConstantIntValue(value)) 700 cst->addBound(BoundType::EQ, value, cOp.value()); 701 } else if (auto loop = getForInductionVarOwner(value)) { 702 if (failed(cst->addAffineForOpDomain(loop))) 703 return failure(); 704 } 705 } 706 707 // Add slices bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]' 708 LogicalResult ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]); 709 assert(succeeded(ret) && 710 "should not fail as we never have semi-affine slice maps"); 711 (void)ret; 712 return success(); 713 } 714 715 // Clears state bounds and operand state. 716 void ComputationSliceState::clearBounds() { 717 lbs.clear(); 718 ubs.clear(); 719 lbOperands.clear(); 720 ubOperands.clear(); 721 } 722 723 void ComputationSliceState::dump() const { 724 llvm::errs() << "\tIVs:\n"; 725 for (Value iv : ivs) 726 llvm::errs() << "\t\t" << iv << "\n"; 727 728 llvm::errs() << "\tLBs:\n"; 729 for (auto en : llvm::enumerate(lbs)) { 730 llvm::errs() << "\t\t" << en.value() << "\n"; 731 llvm::errs() << "\t\tOperands:\n"; 732 for (Value lbOp : lbOperands[en.index()]) 733 llvm::errs() << "\t\t\t" << lbOp << "\n"; 734 } 735 736 llvm::errs() << "\tUBs:\n"; 737 for (auto en : llvm::enumerate(ubs)) { 738 llvm::errs() << "\t\t" << en.value() << "\n"; 739 llvm::errs() << "\t\tOperands:\n"; 740 for (Value ubOp : ubOperands[en.index()]) 741 llvm::errs() << "\t\t\t" << ubOp << "\n"; 742 } 743 } 744 745 /// Fast check to determine if the computation slice is maximal. Returns true if 746 /// each slice dimension maps to an existing dst dimension and both the src 747 /// and the dst loops for those dimensions have the same bounds. Returns false 748 /// if both the src and the dst loops don't have the same bounds. Returns 749 /// std::nullopt if none of the above can be proven. 750 std::optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const { 751 assert(lbs.size() == ubs.size() && !lbs.empty() && !ivs.empty() && 752 "Unexpected number of lbs, ubs and ivs in slice"); 753 754 for (unsigned i = 0, end = lbs.size(); i < end; ++i) { 755 AffineMap lbMap = lbs[i]; 756 AffineMap ubMap = ubs[i]; 757 758 // Check if this slice is just an equality along this dimension. 759 if (!lbMap || !ubMap || lbMap.getNumResults() != 1 || 760 ubMap.getNumResults() != 1 || 761 lbMap.getResult(0) + 1 != ubMap.getResult(0) || 762 // The condition above will be true for maps describing a single 763 // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1). 764 // Make sure we skip those cases by checking that the lb result is not 765 // just a constant. 766 isa<AffineConstantExpr>(lbMap.getResult(0))) 767 return std::nullopt; 768 769 // Limited support: we expect the lb result to be just a loop dimension for 770 // now. 771 AffineDimExpr result = dyn_cast<AffineDimExpr>(lbMap.getResult(0)); 772 if (!result) 773 return std::nullopt; 774 775 // Retrieve dst loop bounds. 776 AffineForOp dstLoop = 777 getForInductionVarOwner(lbOperands[i][result.getPosition()]); 778 if (!dstLoop) 779 return std::nullopt; 780 AffineMap dstLbMap = dstLoop.getLowerBoundMap(); 781 AffineMap dstUbMap = dstLoop.getUpperBoundMap(); 782 783 // Retrieve src loop bounds. 784 AffineForOp srcLoop = getForInductionVarOwner(ivs[i]); 785 assert(srcLoop && "Expected affine for"); 786 AffineMap srcLbMap = srcLoop.getLowerBoundMap(); 787 AffineMap srcUbMap = srcLoop.getUpperBoundMap(); 788 789 // Limited support: we expect simple src and dst loops with a single 790 // constant component per bound for now. 
791 if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 || 792 dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1) 793 return std::nullopt; 794 795 AffineExpr srcLbResult = srcLbMap.getResult(0); 796 AffineExpr dstLbResult = dstLbMap.getResult(0); 797 AffineExpr srcUbResult = srcUbMap.getResult(0); 798 AffineExpr dstUbResult = dstUbMap.getResult(0); 799 if (!isa<AffineConstantExpr>(srcLbResult) || 800 !isa<AffineConstantExpr>(srcUbResult) || 801 !isa<AffineConstantExpr>(dstLbResult) || 802 !isa<AffineConstantExpr>(dstUbResult)) 803 return std::nullopt; 804 805 // Check if src and dst loop bounds are the same. If not, we can guarantee 806 // that the slice is not maximal. 807 if (srcLbResult != dstLbResult || srcUbResult != dstUbResult || 808 srcLoop.getStep() != dstLoop.getStep()) 809 return false; 810 } 811 812 return true; 813 } 814 815 /// Returns true if it is deterministically verified that the original iteration 816 /// space of the slice is contained within the new iteration space that is 817 /// created after fusing 'this' slice into its destination. 818 std::optional<bool> ComputationSliceState::isSliceValid() const { 819 // Fast check to determine if the slice is valid. If the following conditions 820 // are verified to be true, slice is declared valid by the fast check: 821 // 1. Each slice loop is a single iteration loop bound in terms of a single 822 // destination loop IV. 823 // 2. Loop bounds of the destination loop IV (from above) and those of the 824 // source loop IV are exactly the same. 825 // If the fast check is inconclusive or false, we proceed with a more 826 // expensive analysis. 827 // TODO: Store the result of the fast check, as it might be used again in 828 // `canRemoveSrcNodeAfterFusion`. 829 std::optional<bool> isValidFastCheck = isSliceMaximalFastCheck(); 830 if (isValidFastCheck && *isValidFastCheck) 831 return true; 832 833 // Create constraints for the source loop nest using which slice is computed. 834 FlatAffineValueConstraints srcConstraints; 835 // TODO: Store the source's domain to avoid computation at each depth. 836 if (failed(getSourceAsConstraints(srcConstraints))) { 837 LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n"); 838 return std::nullopt; 839 } 840 // As the set difference utility currently cannot handle symbols in its 841 // operands, validity of the slice cannot be determined. 842 if (srcConstraints.getNumSymbolVars() > 0) { 843 LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n"); 844 return std::nullopt; 845 } 846 // TODO: Handle local vars in the source domains while using the 'projectOut' 847 // utility below. Currently, aligning is not done assuming that there will be 848 // no local vars in the source domain. 849 if (srcConstraints.getNumLocalVars() != 0) { 850 LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n"); 851 return std::nullopt; 852 } 853 854 // Create constraints for the slice loop nest that would be created if the 855 // fusion succeeds. 856 FlatAffineValueConstraints sliceConstraints; 857 if (failed(getAsConstraints(&sliceConstraints))) { 858 LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n"); 859 return std::nullopt; 860 } 861 862 // Projecting out every dimension other than the 'ivs' to express slice's 863 // domain completely in terms of source's IVs. 
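  // In 'sliceConstraints' (built by getAsConstraints above), the source IVs
  // are the leading 'ivs.size()' dims; the remaining variables (destination
  // IVs, symbols, and any locals) are the ones projected out below.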
864 sliceConstraints.projectOut(ivs.size(), 865 sliceConstraints.getNumVars() - ivs.size()); 866 867 LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n"); 868 LLVM_DEBUG(srcConstraints.dump()); 869 LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds " 870 "(expressed in terms of its source's IVs):\n"); 871 LLVM_DEBUG(sliceConstraints.dump()); 872 873 // TODO: Store 'srcSet' to avoid recalculating for each depth. 874 PresburgerSet srcSet(srcConstraints); 875 PresburgerSet sliceSet(sliceConstraints); 876 PresburgerSet diffSet = sliceSet.subtract(srcSet); 877 878 if (!diffSet.isIntegerEmpty()) { 879 LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n"); 880 return false; 881 } 882 return true; 883 } 884 885 /// Returns true if the computation slice encloses all the iterations of the 886 /// sliced loop nest. Returns false if it does not. Returns std::nullopt if it 887 /// cannot determine if the slice is maximal or not. 888 std::optional<bool> ComputationSliceState::isMaximal() const { 889 // Fast check to determine if the computation slice is maximal. If the result 890 // is inconclusive, we proceed with a more expensive analysis. 891 std::optional<bool> isMaximalFastCheck = isSliceMaximalFastCheck(); 892 if (isMaximalFastCheck) 893 return isMaximalFastCheck; 894 895 // Create constraints for the src loop nest being sliced. 896 FlatAffineValueConstraints srcConstraints(/*numDims=*/ivs.size(), 897 /*numSymbols=*/0, 898 /*numLocals=*/0, ivs); 899 for (Value iv : ivs) { 900 AffineForOp loop = getForInductionVarOwner(iv); 901 assert(loop && "Expected affine for"); 902 if (failed(srcConstraints.addAffineForOpDomain(loop))) 903 return std::nullopt; 904 } 905 906 // Create constraints for the slice using the dst loop nest information. We 907 // retrieve existing dst loops from the lbOperands. 908 SmallVector<Value> consumerIVs; 909 for (Value lbOp : lbOperands[0]) 910 if (getForInductionVarOwner(lbOp)) 911 consumerIVs.push_back(lbOp); 912 913 // Add empty IV Values for those new loops that are not equalities and, 914 // therefore, are not yet materialized in the IR. 915 for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i) 916 consumerIVs.push_back(Value()); 917 918 FlatAffineValueConstraints sliceConstraints(/*numDims=*/consumerIVs.size(), 919 /*numSymbols=*/0, 920 /*numLocals=*/0, consumerIVs); 921 922 if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0]))) 923 return std::nullopt; 924 925 if (srcConstraints.getNumDimVars() != sliceConstraints.getNumDimVars()) 926 // Constraint dims are different. The integer set difference can't be 927 // computed so we don't know if the slice is maximal. 928 return std::nullopt; 929 930 // Compute the difference between the src loop nest and the slice integer 931 // sets. 
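  // An empty difference means every point in the source domain is also in the
  // slice domain, i.e., the slice covers the full source iteration space and
  // is thus maximal.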
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = srcSet.subtract(sliceSet);
  return diffSet.isIntegerEmpty();
}

unsigned MemRefRegion::getRank() const {
  return cast<MemRefType>(memref.getType()).getRank();
}

std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
    SmallVectorImpl<int64_t> *shape, std::vector<SmallVector<int64_t, 4>> *lbs,
    SmallVectorImpl<int64_t> *lbDivisors) const {
  auto memRefType = cast<MemRefType>(memref.getType());
  unsigned rank = memRefType.getRank();
  if (shape)
    shape->reserve(rank);

  assert(rank == cst.getNumDimVars() && "inconsistent memref region");

  // Use a copy of the region constraints that has upper/lower bounds for each
  // memref dimension with static size added to guard against potential
  // over-approximation from projection or union bounding box. We don't add
  // these bounds to the region itself since they might just be redundant
  // constraints that would need non-trivial means to eliminate.
  FlatAffineValueConstraints cstWithShapeBounds(cst);
  for (unsigned r = 0; r < rank; r++) {
    cstWithShapeBounds.addBound(BoundType::LB, r, 0);
    int64_t dimSize = memRefType.getDimSize(r);
    if (ShapedType::isDynamic(dimSize))
      continue;
    cstWithShapeBounds.addBound(BoundType::UB, r, dimSize - 1);
  }

  // Find a constant upper bound on the extent of this memref region along each
  // dimension.
  int64_t numElements = 1;
  int64_t diffConstant;
  int64_t lbDivisor;
  for (unsigned d = 0; d < rank; d++) {
    SmallVector<int64_t, 4> lb;
    std::optional<int64_t> diff =
        cstWithShapeBounds.getConstantBoundOnDimSize64(d, &lb, &lbDivisor);
    if (diff.has_value()) {
      diffConstant = *diff;
      assert(diffConstant >= 0 && "Dim size bound can't be negative");
      assert(lbDivisor > 0);
    } else {
      // If no constant bound is found, then it can always be bounded by the
      // memref's dim size if the latter has a constant size along this dim.
      auto dimSize = memRefType.getDimSize(d);
      if (dimSize == ShapedType::kDynamic)
        return std::nullopt;
      diffConstant = dimSize;
      // Lower bound becomes 0.
987 lb.resize(cstWithShapeBounds.getNumSymbolVars() + 1, 0); 988 lbDivisor = 1; 989 } 990 numElements *= diffConstant; 991 if (lbs) { 992 lbs->push_back(lb); 993 assert(lbDivisors && "both lbs and lbDivisor or none"); 994 lbDivisors->push_back(lbDivisor); 995 } 996 if (shape) { 997 shape->push_back(diffConstant); 998 } 999 } 1000 return numElements; 1001 } 1002 1003 void MemRefRegion::getLowerAndUpperBound(unsigned pos, AffineMap &lbMap, 1004 AffineMap &ubMap) const { 1005 assert(pos < cst.getNumDimVars() && "invalid position"); 1006 auto memRefType = cast<MemRefType>(memref.getType()); 1007 unsigned rank = memRefType.getRank(); 1008 1009 assert(rank == cst.getNumDimVars() && "inconsistent memref region"); 1010 1011 auto boundPairs = cst.getLowerAndUpperBound( 1012 pos, /*offset=*/0, /*num=*/rank, cst.getNumDimAndSymbolVars(), 1013 /*localExprs=*/{}, memRefType.getContext()); 1014 lbMap = boundPairs.first; 1015 ubMap = boundPairs.second; 1016 assert(lbMap && "lower bound for a region must exist"); 1017 assert(ubMap && "upper bound for a region must exist"); 1018 assert(lbMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank); 1019 assert(ubMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank); 1020 } 1021 1022 LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) { 1023 assert(memref == other.memref); 1024 return cst.unionBoundingBox(*other.getConstraints()); 1025 } 1026 1027 /// Computes the memory region accessed by this memref with the region 1028 /// represented as constraints symbolic/parametric in 'loopDepth' loops 1029 /// surrounding opInst and any additional Function symbols. 1030 // For example, the memref region for this load operation at loopDepth = 1 will 1031 // be as below: 1032 // 1033 // affine.for %i = 0 to 32 { 1034 // affine.for %ii = %i to (d0) -> (d0 + 8) (%i) { 1035 // load %A[%ii] 1036 // } 1037 // } 1038 // 1039 // region: {memref = %A, write = false, {%i <= m0 <= %i + 7} } 1040 // The last field is a 2-d FlatAffineValueConstraints symbolic in %i. 1041 // 1042 // TODO: extend this to any other memref dereferencing ops 1043 // (dma_start, dma_wait). 1044 LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, 1045 const ComputationSliceState *sliceState, 1046 bool addMemRefDimBounds) { 1047 assert((isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) && 1048 "affine read/write op expected"); 1049 1050 MemRefAccess access(op); 1051 memref = access.memref; 1052 write = access.isStore(); 1053 1054 unsigned rank = access.getRank(); 1055 1056 LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op 1057 << "\ndepth: " << loopDepth << "\n";); 1058 1059 // 0-d memrefs. 1060 if (rank == 0) { 1061 SmallVector<Value, 4> ivs; 1062 getAffineIVs(*op, ivs); 1063 assert(loopDepth <= ivs.size() && "invalid 'loopDepth'"); 1064 // The first 'loopDepth' IVs are symbols for this region. 1065 ivs.resize(loopDepth); 1066 // A 0-d memref has a 0-d region. 1067 cst = FlatAffineValueConstraints(rank, loopDepth, /*numLocals=*/0, ivs); 1068 return success(); 1069 } 1070 1071 // Build the constraints for this region. 1072 AffineValueMap accessValueMap; 1073 access.getAccessMap(&accessValueMap); 1074 AffineMap accessMap = accessValueMap.getAffineMap(); 1075 1076 unsigned numDims = accessMap.getNumDims(); 1077 unsigned numSymbols = accessMap.getNumSymbols(); 1078 unsigned numOperands = accessValueMap.getNumOperands(); 1079 // Merge operands with slice operands. 
1080 SmallVector<Value, 4> operands; 1081 operands.resize(numOperands); 1082 for (unsigned i = 0; i < numOperands; ++i) 1083 operands[i] = accessValueMap.getOperand(i); 1084 1085 if (sliceState != nullptr) { 1086 operands.reserve(operands.size() + sliceState->lbOperands[0].size()); 1087 // Append slice operands to 'operands' as symbols. 1088 for (auto extraOperand : sliceState->lbOperands[0]) { 1089 if (!llvm::is_contained(operands, extraOperand)) { 1090 operands.push_back(extraOperand); 1091 numSymbols++; 1092 } 1093 } 1094 } 1095 // We'll first associate the dims and symbols of the access map to the dims 1096 // and symbols resp. of cst. This will change below once cst is 1097 // fully constructed out. 1098 cst = FlatAffineValueConstraints(numDims, numSymbols, 0, operands); 1099 1100 // Add equality constraints. 1101 // Add inequalities for loop lower/upper bounds. 1102 for (unsigned i = 0; i < numDims + numSymbols; ++i) { 1103 auto operand = operands[i]; 1104 if (auto affineFor = getForInductionVarOwner(operand)) { 1105 // Note that cst can now have more dimensions than accessMap if the 1106 // bounds expressions involve outer loops or other symbols. 1107 // TODO: rewrite this to use getInstIndexSet; this way 1108 // conditionals will be handled when the latter supports it. 1109 if (failed(cst.addAffineForOpDomain(affineFor))) 1110 return failure(); 1111 } else if (auto parallelOp = getAffineParallelInductionVarOwner(operand)) { 1112 if (failed(cst.addAffineParallelOpDomain(parallelOp))) 1113 return failure(); 1114 } else if (isValidSymbol(operand)) { 1115 // Check if the symbol is a constant. 1116 Value symbol = operand; 1117 if (auto constVal = getConstantIntValue(symbol)) 1118 cst.addBound(BoundType::EQ, symbol, constVal.value()); 1119 } else { 1120 LLVM_DEBUG(llvm::dbgs() << "unknown affine dimensional value"); 1121 return failure(); 1122 } 1123 } 1124 1125 // Add lower/upper bounds on loop IVs using bounds from 'sliceState'. 1126 if (sliceState != nullptr) { 1127 // Add dim and symbol slice operands. 1128 for (auto operand : sliceState->lbOperands[0]) { 1129 cst.addInductionVarOrTerminalSymbol(operand); 1130 } 1131 // Add upper/lower bounds from 'sliceState' to 'cst'. 1132 LogicalResult ret = 1133 cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs, 1134 sliceState->lbOperands[0]); 1135 assert(succeeded(ret) && 1136 "should not fail as we never have semi-affine slice maps"); 1137 (void)ret; 1138 } 1139 1140 // Add access function equalities to connect loop IVs to data dimensions. 1141 if (failed(cst.composeMap(&accessValueMap))) { 1142 op->emitError("getMemRefRegion: compose affine map failed"); 1143 LLVM_DEBUG(accessValueMap.getAffineMap().dump()); 1144 return failure(); 1145 } 1146 1147 // Set all variables appearing after the first 'rank' variables as 1148 // symbolic variables - so that the ones corresponding to the memref 1149 // dimensions are the dimensional variables for the memref region. 1150 cst.setDimSymbolSeparation(cst.getNumDimAndSymbolVars() - rank); 1151 1152 // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which 1153 // this memref region is symbolic. 
  SmallVector<Value, 4> enclosingIVs;
  getAffineIVs(*op, enclosingIVs);
  assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
  enclosingIVs.resize(loopDepth);
  SmallVector<Value, 4> vars;
  cst.getValues(cst.getNumDimVars(), cst.getNumDimAndSymbolVars(), &vars);
  for (Value var : vars) {
    if ((isAffineInductionVar(var)) && !llvm::is_contained(enclosingIVs, var)) {
      cst.projectOut(var);
    }
  }

  // Project out any local variables (these would have been added for any
  // mod/divs).
  cst.projectOut(cst.getNumDimAndSymbolVars(), cst.getNumLocalVars());

  // Constant fold any symbolic variables.
  cst.constantFoldVarRange(/*pos=*/cst.getNumDimVars(),
                           /*num=*/cst.getNumSymbolVars());

  assert(cst.getNumDimVars() == rank && "unexpected MemRefRegion format");

  // Add upper/lower bounds for each memref dimension with static size
  // to guard against potential over-approximation from projection.
  // TODO: Support dynamic memref dimensions.
  if (addMemRefDimBounds) {
    auto memRefType = cast<MemRefType>(memref.getType());
    for (unsigned r = 0; r < rank; r++) {
      cst.addBound(BoundType::LB, /*pos=*/r, /*value=*/0);
      if (memRefType.isDynamicDim(r))
        continue;
      cst.addBound(BoundType::UB, /*pos=*/r, memRefType.getDimSize(r) - 1);
    }
  }
  cst.removeTrivialRedundancy();

  LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
  LLVM_DEBUG(cst.dump());
  return success();
}

std::optional<int64_t>
mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else if (auto vectorType = dyn_cast<VectorType>(elementType)) {
    if (vectorType.getElementType().isIntOrFloat())
      sizeInBits =
          vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
    else
      return std::nullopt;
  } else {
    return std::nullopt;
  }
  return llvm::divideCeil(sizeInBits, 8);
}

// Returns the size of this region in bytes.
std::optional<int64_t> MemRefRegion::getRegionSize() {
  auto memRefType = cast<MemRefType>(memref.getType());

  if (!memRefType.getLayout().isIdentity()) {
    LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
    return std::nullopt;
  }

  // Compute the extents of the buffer.
  std::optional<int64_t> numElements = getConstantBoundingSizeAndShape();
  if (!numElements) {
    LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
    return std::nullopt;
  }
  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType);
  if (!eltSize)
    return std::nullopt;
  return *eltSize * *numElements;
}

/// Returns the size of memref data in bytes if it's statically shaped,
/// std::nullopt otherwise. If the memref's element type is a vector, the size
/// of the vector is taken into account as well.
// TODO: improve/complete this when we have target data.
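// For example (illustrative only): memref<4x4xf32> yields 4 * 4 * 4 = 64
// bytes, and memref<8xvector<4xf32>> yields 8 * 16 = 128 bytes.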
1245 std::optional<uint64_t> 1246 mlir::affine::getIntOrFloatMemRefSizeInBytes(MemRefType memRefType) { 1247 if (!memRefType.hasStaticShape()) 1248 return std::nullopt; 1249 auto elementType = memRefType.getElementType(); 1250 if (!elementType.isIntOrFloat() && !isa<VectorType>(elementType)) 1251 return std::nullopt; 1252 1253 auto sizeInBytes = getMemRefIntOrFloatEltSizeInBytes(memRefType); 1254 if (!sizeInBytes) 1255 return std::nullopt; 1256 for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) { 1257 sizeInBytes = *sizeInBytes * memRefType.getDimSize(i); 1258 } 1259 return sizeInBytes; 1260 } 1261 1262 template <typename LoadOrStoreOp> 1263 LogicalResult mlir::affine::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp, 1264 bool emitError) { 1265 static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface, 1266 AffineWriteOpInterface>::value, 1267 "argument should be either a AffineReadOpInterface or a " 1268 "AffineWriteOpInterface"); 1269 1270 Operation *op = loadOrStoreOp.getOperation(); 1271 MemRefRegion region(op->getLoc()); 1272 if (failed(region.compute(op, /*loopDepth=*/0, /*sliceState=*/nullptr, 1273 /*addMemRefDimBounds=*/false))) 1274 return success(); 1275 1276 LLVM_DEBUG(llvm::dbgs() << "Memory region"); 1277 LLVM_DEBUG(region.getConstraints()->dump()); 1278 1279 bool outOfBounds = false; 1280 unsigned rank = loadOrStoreOp.getMemRefType().getRank(); 1281 1282 // For each dimension, check for out of bounds. 1283 for (unsigned r = 0; r < rank; r++) { 1284 FlatAffineValueConstraints ucst(*region.getConstraints()); 1285 1286 // Intersect memory region with constraint capturing out of bounds (both out 1287 // of upper and out of lower), and check if the constraint system is 1288 // feasible. If it is, there is at least one point out of bounds. 1289 SmallVector<int64_t, 4> ineq(rank + 1, 0); 1290 int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r); 1291 // TODO: handle dynamic dim sizes. 1292 if (dimSize == -1) 1293 continue; 1294 1295 // Check for overflow: d_i >= memref dim size. 1296 ucst.addBound(BoundType::LB, r, dimSize); 1297 outOfBounds = !ucst.isEmpty(); 1298 if (outOfBounds && emitError) { 1299 loadOrStoreOp.emitOpError() 1300 << "memref out of upper bound access along dimension #" << (r + 1); 1301 } 1302 1303 // Check for a negative index. 1304 FlatAffineValueConstraints lcst(*region.getConstraints()); 1305 std::fill(ineq.begin(), ineq.end(), 0); 1306 // d_i <= -1; 1307 lcst.addBound(BoundType::UB, r, -1); 1308 outOfBounds = !lcst.isEmpty(); 1309 if (outOfBounds && emitError) { 1310 loadOrStoreOp.emitOpError() 1311 << "memref out of lower bound access along dimension #" << (r + 1); 1312 } 1313 } 1314 return failure(outOfBounds); 1315 } 1316 1317 // Explicitly instantiate the template so that the compiler knows we need them! 1318 template LogicalResult 1319 mlir::affine::boundCheckLoadOrStoreOp(AffineReadOpInterface loadOp, 1320 bool emitError); 1321 template LogicalResult 1322 mlir::affine::boundCheckLoadOrStoreOp(AffineWriteOpInterface storeOp, 1323 bool emitError); 1324 1325 // Returns in 'positions' the Block positions of 'op' in each ancestor 1326 // Block from the Block containing operation, stopping at 'limitBlock'. 1327 static void findInstPosition(Operation *op, Block *limitBlock, 1328 SmallVectorImpl<unsigned> *positions) { 1329 Block *block = op->getBlock(); 1330 while (block != limitBlock) { 1331 // FIXME: This algorithm is unnecessarily O(n) and should be improved to not 1332 // rely on linear scans. 
1333 int instPosInBlock = std::distance(block->begin(), op->getIterator()); 1334 positions->push_back(instPosInBlock); 1335 op = block->getParentOp(); 1336 block = op->getBlock(); 1337 } 1338 std::reverse(positions->begin(), positions->end()); 1339 } 1340 1341 // Returns the Operation in a possibly nested set of Blocks, where the 1342 // position of the operation is represented by 'positions', which has a 1343 // Block position for each level of nesting. 1344 static Operation *getInstAtPosition(ArrayRef<unsigned> positions, 1345 unsigned level, Block *block) { 1346 unsigned i = 0; 1347 for (auto &op : *block) { 1348 if (i != positions[level]) { 1349 ++i; 1350 continue; 1351 } 1352 if (level == positions.size() - 1) 1353 return &op; 1354 if (auto childAffineForOp = dyn_cast<AffineForOp>(op)) 1355 return getInstAtPosition(positions, level + 1, 1356 childAffineForOp.getBody()); 1357 1358 for (auto ®ion : op.getRegions()) { 1359 for (auto &b : region) 1360 if (auto *ret = getInstAtPosition(positions, level + 1, &b)) 1361 return ret; 1362 } 1363 return nullptr; 1364 } 1365 return nullptr; 1366 } 1367 1368 // Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'. 1369 static LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs, 1370 FlatAffineValueConstraints *cst) { 1371 for (unsigned i = 0, e = cst->getNumDimVars(); i < e; ++i) { 1372 auto value = cst->getValue(i); 1373 if (ivs.count(value) == 0) { 1374 assert(isAffineForInductionVar(value)); 1375 auto loop = getForInductionVarOwner(value); 1376 if (failed(cst->addAffineForOpDomain(loop))) 1377 return failure(); 1378 } 1379 } 1380 return success(); 1381 } 1382 1383 /// Returns the innermost common loop depth for the set of operations in 'ops'. 1384 // TODO: Move this to LoopUtils. 1385 unsigned mlir::affine::getInnermostCommonLoopDepth( 1386 ArrayRef<Operation *> ops, SmallVectorImpl<AffineForOp> *surroundingLoops) { 1387 unsigned numOps = ops.size(); 1388 assert(numOps > 0 && "Expected at least one operation"); 1389 1390 std::vector<SmallVector<AffineForOp, 4>> loops(numOps); 1391 unsigned loopDepthLimit = std::numeric_limits<unsigned>::max(); 1392 for (unsigned i = 0; i < numOps; ++i) { 1393 getAffineForIVs(*ops[i], &loops[i]); 1394 loopDepthLimit = 1395 std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size())); 1396 } 1397 1398 unsigned loopDepth = 0; 1399 for (unsigned d = 0; d < loopDepthLimit; ++d) { 1400 unsigned i; 1401 for (i = 1; i < numOps; ++i) { 1402 if (loops[i - 1][d] != loops[i][d]) 1403 return loopDepth; 1404 } 1405 if (surroundingLoops) 1406 surroundingLoops->push_back(loops[i - 1][d]); 1407 ++loopDepth; 1408 } 1409 return loopDepth; 1410 } 1411 1412 /// Computes in 'sliceUnion' the union of all slice bounds computed at 1413 /// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and 1414 /// then verifies if it is valid. Returns 'SliceComputationResult::Success' if 1415 /// union was computed correctly, an appropriate failure otherwise. 1416 SliceComputationResult 1417 mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA, 1418 ArrayRef<Operation *> opsB, unsigned loopDepth, 1419 unsigned numCommonLoops, bool isBackwardSlice, 1420 ComputationSliceState *sliceUnion) { 1421 // Compute the union of slice bounds between all pairs in 'opsA' and 1422 // 'opsB' in 'sliceUnionCst'. 
1423 FlatAffineValueConstraints sliceUnionCst; 1424 assert(sliceUnionCst.getNumDimAndSymbolVars() == 0); 1425 std::vector<std::pair<Operation *, Operation *>> dependentOpPairs; 1426 for (auto *i : opsA) { 1427 MemRefAccess srcAccess(i); 1428 for (auto *j : opsB) { 1429 MemRefAccess dstAccess(j); 1430 if (srcAccess.memref != dstAccess.memref) 1431 continue; 1432 // Check if 'loopDepth' exceeds nesting depth of src/dst ops. 1433 if ((!isBackwardSlice && loopDepth > getNestingDepth(i)) || 1434 (isBackwardSlice && loopDepth > getNestingDepth(j))) { 1435 LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n"); 1436 return SliceComputationResult::GenericFailure; 1437 } 1438 1439 bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) && 1440 isa<AffineReadOpInterface>(dstAccess.opInst); 1441 FlatAffineValueConstraints dependenceConstraints; 1442 // Check dependence between 'srcAccess' and 'dstAccess'. 1443 DependenceResult result = checkMemrefAccessDependence( 1444 srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1, 1445 &dependenceConstraints, /*dependenceComponents=*/nullptr, 1446 /*allowRAR=*/readReadAccesses); 1447 if (result.value == DependenceResult::Failure) { 1448 LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n"); 1449 return SliceComputationResult::GenericFailure; 1450 } 1451 if (result.value == DependenceResult::NoDependence) 1452 continue; 1453 dependentOpPairs.emplace_back(i, j); 1454 1455 // Compute slice bounds for 'srcAccess' and 'dstAccess'. 1456 ComputationSliceState tmpSliceState; 1457 mlir::affine::getComputationSliceState(i, j, &dependenceConstraints, 1458 loopDepth, isBackwardSlice, 1459 &tmpSliceState); 1460 1461 if (sliceUnionCst.getNumDimAndSymbolVars() == 0) { 1462 // Initialize 'sliceUnionCst' with the bounds computed in previous step. 1463 if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) { 1464 LLVM_DEBUG(llvm::dbgs() 1465 << "Unable to compute slice bound constraints\n"); 1466 return SliceComputationResult::GenericFailure; 1467 } 1468 assert(sliceUnionCst.getNumDimAndSymbolVars() > 0); 1469 continue; 1470 } 1471 1472 // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'. 1473 FlatAffineValueConstraints tmpSliceCst; 1474 if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) { 1475 LLVM_DEBUG(llvm::dbgs() 1476 << "Unable to compute slice bound constraints\n"); 1477 return SliceComputationResult::GenericFailure; 1478 } 1479 1480 // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed. 1481 if (!sliceUnionCst.areVarsAlignedWithOther(tmpSliceCst)) { 1482 1483 // Pre-constraint var alignment: record loop IVs used in each constraint 1484 // system. 1485 SmallPtrSet<Value, 8> sliceUnionIVs; 1486 for (unsigned k = 0, l = sliceUnionCst.getNumDimVars(); k < l; ++k) 1487 sliceUnionIVs.insert(sliceUnionCst.getValue(k)); 1488 SmallPtrSet<Value, 8> tmpSliceIVs; 1489 for (unsigned k = 0, l = tmpSliceCst.getNumDimVars(); k < l; ++k) 1490 tmpSliceIVs.insert(tmpSliceCst.getValue(k)); 1491 1492 sliceUnionCst.mergeAndAlignVarsWithOther(/*offset=*/0, &tmpSliceCst); 1493 1494 // Post-constraint var alignment: add loop IV bounds missing after 1495 // var alignment to constraint systems. This can occur if one constraint 1496 // system uses an loop IV that is not used by the other. The call 1497 // to unionBoundingBox below expects constraints for each Loop IV, even 1498 // if they are the unsliced full loop bounds added here. 
        if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
          return SliceComputationResult::GenericFailure;
        if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
          return SliceComputationResult::GenericFailure;
      }
      // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
      if (sliceUnionCst.getNumLocalVars() > 0 ||
          tmpSliceCst.getNumLocalVars() > 0 ||
          failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Unable to compute union bounding box of slice bounds\n");
        return SliceComputationResult::GenericFailure;
      }
    }
  }

  // Empty union.
  if (sliceUnionCst.getNumDimAndSymbolVars() == 0)
    return SliceComputationResult::GenericFailure;

  // Gather loops surrounding ops from loop nest where slice will be inserted.
  SmallVector<Operation *, 4> ops;
  for (auto &dep : dependentOpPairs) {
    ops.push_back(isBackwardSlice ? dep.second : dep.first);
  }
  SmallVector<AffineForOp, 4> surroundingLoops;
  unsigned innermostCommonLoopDepth =
      getInnermostCommonLoopDepth(ops, &surroundingLoops);
  if (loopDepth > innermostCommonLoopDepth) {
    LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
    return SliceComputationResult::GenericFailure;
  }

  // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
  unsigned numSliceLoopIVs = sliceUnionCst.getNumDimVars();

  // Convert any dst loop IVs which are symbol variables to dim variables.
  sliceUnionCst.convertLoopIVSymbolsToDims();
  sliceUnion->clearBounds();
  sliceUnion->lbs.resize(numSliceLoopIVs, AffineMap());
  sliceUnion->ubs.resize(numSliceLoopIVs, AffineMap());

  // Get slice bounds from slice union constraints 'sliceUnionCst'.
  sliceUnionCst.getSliceBounds(/*offset=*/0, numSliceLoopIVs,
                               opsA[0]->getContext(), &sliceUnion->lbs,
                               &sliceUnion->ubs);

  // Add slice bound operands of union.
  SmallVector<Value, 4> sliceBoundOperands;
  sliceUnionCst.getValues(numSliceLoopIVs,
                          sliceUnionCst.getNumDimAndSymbolVars(),
                          &sliceBoundOperands);

  // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
  sliceUnion->ivs.clear();
  sliceUnionCst.getValues(0, numSliceLoopIVs, &sliceUnion->ivs);

  // Set loop nest insertion point to block start at 'loopDepth'.
  sliceUnion->insertPoint =
      isBackwardSlice
          ? surroundingLoops[loopDepth - 1].getBody()->begin()
          : std::prev(surroundingLoops[loopDepth - 1].getBody()->end());

  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
  // canonicalization.
  sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
  sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);

  // Check if the slice computed is valid. Return success only if it is
  // verified that the slice is valid, otherwise return appropriate failure
  // status.
  std::optional<bool> isSliceValid = sliceUnion->isSliceValid();
  if (!isSliceValid) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
    return SliceComputationResult::GenericFailure;
  }
  if (!*isSliceValid)
    return SliceComputationResult::IncorrectSliceFailure;

  return SliceComputationResult::Success;
}

// TODO: extend this to handle multiple result maps.
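// For instance, with lbMap = (d0) -> (d0) and ubMap = (d0) -> (d0 + 4), the
// difference simplifies to the constant 4; with ubMap = (d0) -> (d0 * 2) it is
// non-constant and std::nullopt is returned. (Example maps for illustration.)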
static std::optional<uint64_t> getConstDifference(AffineMap lbMap,
                                                  AffineMap ubMap) {
  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
  assert(lbMap.getNumDims() == ubMap.getNumDims());
  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
  AffineExpr lbExpr(lbMap.getResult(0));
  AffineExpr ubExpr(ubMap.getResult(0));
  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
                                         lbMap.getNumSymbols());
  auto cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr);
  if (!cExpr)
    return std::nullopt;
  return cExpr.getValue();
}

// Builds a map 'tripCountMap' from AffineForOp to constant trip count for the
// loop nest represented by the slice loop bounds in 'slice'. Returns true on
// success, false otherwise (if a non-constant trip count was encountered).
// TODO: Make this work with non-unit step loops.
bool mlir::affine::buildSliceTripCountMap(
    const ComputationSliceState &slice,
    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
  unsigned numSrcLoopIVs = slice.ivs.size();
  // Populate map from AffineForOp -> trip count.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    AffineForOp forOp = getForInductionVarOwner(slice.ivs[i]);
    auto *op = forOp.getOperation();
    AffineMap lbMap = slice.lbs[i];
    AffineMap ubMap = slice.ubs[i];
    // If the lower or upper bound map is null or has no results, the source
    // loop was not sliced at all, and the entire loop is part of the slice.
    if (!lbMap || lbMap.getNumResults() == 0 || !ubMap ||
        ubMap.getNumResults() == 0) {
      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
      if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
        (*tripCountMap)[op] =
            forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
        continue;
      }
      std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
      if (maybeConstTripCount.has_value()) {
        (*tripCountMap)[op] = *maybeConstTripCount;
        continue;
      }
      return false;
    }
    std::optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
    // Slice bounds are created with a constant ub - lb difference.
    if (!tripCount.has_value())
      return false;
    (*tripCountMap)[op] = *tripCount;
  }
  return true;
}

// Returns the number of iterations in the given slice.
uint64_t mlir::affine::getSliceIterationCount(
    const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
  uint64_t iterCount = 1;
  for (const auto &count : sliceTripCountMap) {
    iterCount *= count.second;
  }
  return iterCount;
}

const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
// Computes slice bounds by projecting out any loop IVs from
// 'dependenceConstraints' at depth greater than 'loopDepth', and computes
// slice bounds in 'sliceState' which represent one loop nest's IVs in terms of
// the other loop nest's IVs, symbols, and constants (using 'isBackwardSlice').
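// For a backward slice, the source nest's IVs are re-bounded in terms of the
// destination nest's IVs at depths up to 'loopDepth' (plus any symbols); for a
// forward slice the roles are swapped. When such bounds are later materialized
// (e.g., by insertBackwardComputationSlice), the sliced source loops roughly
// take the form `affine.for %ii = lb(%i, ...) to ub(%i, ...)`, where %i are
// destination IVs. This is an illustrative sketch, not a guaranteed shape.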
void mlir::affine::getComputationSliceState(
    Operation *depSourceOp, Operation *depSinkOp,
    FlatAffineValueConstraints *dependenceConstraints, unsigned loopDepth,
    bool isBackwardSlice, ComputationSliceState *sliceState) {
  // Get loop nest surrounding src operation.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getAffineForIVs(*depSourceOp, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Get loop nest surrounding dst operation.
  SmallVector<AffineForOp, 4> dstLoopIVs;
  getAffineForIVs(*depSinkOp, &dstLoopIVs);
  unsigned numDstLoopIVs = dstLoopIVs.size();

  assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) ||
         (isBackwardSlice && loopDepth <= numDstLoopIVs));

  // Project out dimensions other than those up to 'loopDepth'.
  unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
  unsigned num =
      isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
  dependenceConstraints->projectOut(pos, num);

  // Add slice loop IV values to 'sliceState'.
  unsigned offset = isBackwardSlice ? 0 : loopDepth;
  unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
  dependenceConstraints->getValues(offset, offset + numSliceLoopIVs,
                                   &sliceState->ivs);

  // Set up lower/upper bound affine maps for the slice.
  sliceState->lbs.resize(numSliceLoopIVs, AffineMap());
  sliceState->ubs.resize(numSliceLoopIVs, AffineMap());

  // Get bounds for slice IVs in terms of other IVs, symbols, and constants.
  dependenceConstraints->getSliceBounds(offset, numSliceLoopIVs,
                                        depSourceOp->getContext(),
                                        &sliceState->lbs, &sliceState->ubs);

  // Set up bound operands for the slice's lower and upper bounds.
  SmallVector<Value, 4> sliceBoundOperands;
  unsigned numDimsAndSymbols = dependenceConstraints->getNumDimAndSymbolVars();
  for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
    if (i < offset || i >= offset + numSliceLoopIVs) {
      sliceBoundOperands.push_back(dependenceConstraints->getValue(i));
    }
  }

  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
  // canonicalization.
  sliceState->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
  sliceState->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);

  // Set destination loop nest insertion point to block start at 'dstLoopDepth'.
  sliceState->insertPoint =
      isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
                      : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());

  llvm::SmallDenseSet<Value, 8> sequentialLoops;
  if (isa<AffineReadOpInterface>(depSourceOp) &&
      isa<AffineReadOpInterface>(depSinkOp)) {
    // For read-read access pairs, clear any slice bounds on sequential loops.
    // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
    getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
                       &sequentialLoops);
  }
  auto getSliceLoop = [&](unsigned i) {
    return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
  };
  auto isInnermostInsertion = [&]() {
    return (isBackwardSlice ? loopDepth >= srcLoopIVs.size()
                            : loopDepth >= dstLoopIVs.size());
  };
  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  auto srcIsUnitSlice = [&]() {
    return (buildSliceTripCountMap(*sliceState, &sliceTripCountMap) &&
            (getSliceIterationCount(sliceTripCountMap) == 1));
  };
  // Clear all sliced loop bounds beginning at the first sequential loop or the
  // first loop with a slice fusion barrier attribute.
  for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
    Value iv = getSliceLoop(i).getInductionVar();
    if (sequentialLoops.count(iv) == 0 &&
        getSliceLoop(i)->getAttr(kSliceFusionBarrierAttrName) == nullptr)
      continue;
    // Skip resetting the bounds of a reduction loop inserted in the
    // destination loop nest if all of the following hold:
    // 1. The slice has a single trip count.
    // 2. The loop bounds of the source and destination match.
    // 3. The slice is being inserted at the innermost insertion point.
    std::optional<bool> isMaximal = sliceState->isMaximal();
    if (isLoopParallelAndContainsReduction(getSliceLoop(i)) &&
        isInnermostInsertion() && srcIsUnitSlice() && isMaximal && *isMaximal)
      continue;
    for (unsigned j = i; j < numSliceLoopIVs; ++j) {
      sliceState->lbs[j] = AffineMap();
      sliceState->ubs[j] = AffineMap();
    }
    break;
  }
}

/// Creates a computation slice of the loop nest surrounding 'srcOpInst',
/// updates the slice loop bounds with any non-null bound maps specified in
/// 'sliceState', and inserts this slice into the loop nest surrounding
/// 'dstOpInst' at loop depth 'dstLoopDepth'.
// TODO: extend the slicing utility to compute slices that aren't necessarily a
// one-to-one relation b/w the source and destination. The relation between the
// source and destination could be many-to-many in general.
// TODO: the slice computation is incorrect in the cases where the dependence
// from the source to the destination does not cover the entire destination
// index set. Subtract out the dependent destination iterations from the
// destination index set and check for emptiness --- this is one solution.
AffineForOp mlir::affine::insertBackwardComputationSlice(
    Operation *srcOpInst, Operation *dstOpInst, unsigned dstLoopDepth,
    ComputationSliceState *sliceState) {
  // Get loop nest surrounding src operation.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getAffineForIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Get loop nest surrounding dst operation.
  SmallVector<AffineForOp, 4> dstLoopIVs;
  getAffineForIVs(*dstOpInst, &dstLoopIVs);
  unsigned dstLoopIVsSize = dstLoopIVs.size();
  if (dstLoopDepth > dstLoopIVsSize) {
    dstOpInst->emitError("invalid destination loop depth");
    return AffineForOp();
  }

  // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
  SmallVector<unsigned, 4> positions;
  // TODO: This code is incorrect since srcLoopIVs can be 0-d.
  findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);

  // Clone the src loop nest and insert it at the beginning of the operation
  // block of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
  auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
  OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
  auto sliceLoopNest =
      cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));

  Operation *sliceInst =
      getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
  // Get loop nest surrounding 'sliceInst'.
  SmallVector<AffineForOp, 4> sliceSurroundingLoops;
  getAffineForIVs(*sliceInst, &sliceSurroundingLoops);

  // Sanity check.
  unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
  (void)sliceSurroundingLoopsSize;
  assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
  unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
  (void)sliceLoopLimit;
  assert(sliceLoopLimit >= sliceSurroundingLoopsSize);

  // Update loop bounds for loops in 'sliceLoopNest'.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
    if (AffineMap lbMap = sliceState->lbs[i])
      forOp.setLowerBound(sliceState->lbOperands[i], lbMap);
    if (AffineMap ubMap = sliceState->ubs[i])
      forOp.setUpperBound(sliceState->ubOperands[i], ubMap);
  }
  return sliceLoopNest;
}

// Constructs a MemRefAccess, populating it with the memref, its indices, and
// the op taken from 'loadOrStoreOpInst'.
MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
  if (auto loadOp = dyn_cast<AffineReadOpInterface>(loadOrStoreOpInst)) {
    memref = loadOp.getMemRef();
    opInst = loadOrStoreOpInst;
    llvm::append_range(indices, loadOp.getMapOperands());
  } else {
    assert(isa<AffineWriteOpInterface>(loadOrStoreOpInst) &&
           "Affine read/write op expected");
    auto storeOp = cast<AffineWriteOpInterface>(loadOrStoreOpInst);
    opInst = loadOrStoreOpInst;
    memref = storeOp.getMemRef();
    llvm::append_range(indices, storeOp.getMapOperands());
  }
}

unsigned MemRefAccess::getRank() const {
  return cast<MemRefType>(memref.getType()).getRank();
}

bool MemRefAccess::isStore() const {
  return isa<AffineWriteOpInterface>(opInst);
}

/// Returns the nesting depth of this statement, i.e., the number of loops
/// surrounding this statement.
unsigned mlir::affine::getNestingDepth(Operation *op) {
  Operation *currOp = op;
  unsigned depth = 0;
  while ((currOp = currOp->getParentOp())) {
    if (isa<AffineForOp>(currOp))
      depth++;
  }
  return depth;
}

/// Equal if both affine accesses are provably equivalent (at compile
/// time) when considering the memref, the affine maps and their respective
/// operands. The equality of access functions + operands is checked by
/// subtracting fully composed value maps, and then simplifying the difference
/// using the expression flattener.
/// TODO: this does not account for aliasing of memrefs.
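/// For example, accesses A[%i + 1] and A[1 + %i] (same memref, same %i)
/// compare equal: the composed access maps differ by an expression that
/// simplifies to 0. Accesses of different memrefs never compare equal here,
/// even if the underlying buffers alias. (Illustrative accesses only.)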
bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
  if (memref != rhs.memref)
    return false;

  AffineValueMap diff, thisMap, rhsMap;
  getAccessMap(&thisMap);
  rhs.getAccessMap(&rhsMap);
  AffineValueMap::difference(thisMap, rhsMap, &diff);
  return llvm::all_of(diff.getAffineMap().getResults(),
                      [](AffineExpr e) { return e == 0; });
}

void mlir::affine::getAffineIVs(Operation &op, SmallVectorImpl<Value> &ivs) {
  auto *currOp = op.getParentOp();
  // Traverse up the hierarchy collecting all 'affine.for' and
  // 'affine.parallel' operations while skipping over 'affine.if' operations.
  while (currOp) {
    if (AffineForOp currAffineForOp = dyn_cast<AffineForOp>(currOp))
      ivs.push_back(currAffineForOp.getInductionVar());
    else if (auto parOp = dyn_cast<AffineParallelOp>(currOp))
      llvm::append_range(ivs, parOp.getIVs());
    currOp = currOp->getParentOp();
  }
  std::reverse(ivs.begin(), ivs.end());
}

/// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
/// where each list orders loops from outermost to innermost in the loop nest.
unsigned mlir::affine::getNumCommonSurroundingLoops(Operation &a,
                                                    Operation &b) {
  SmallVector<Value, 4> loopsA, loopsB;
  getAffineIVs(a, loopsA);
  getAffineIVs(b, loopsB);

  unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
  unsigned numCommonLoops = 0;
  for (unsigned i = 0; i < minNumLoops; ++i) {
    if (loopsA[i] != loopsB[i])
      break;
    ++numCommonLoops;
  }
  return numCommonLoops;
}

static std::optional<int64_t> getMemoryFootprintBytes(Block &block,
                                                      Block::iterator start,
                                                      Block::iterator end,
                                                      int memorySpace) {
  SmallDenseMap<Value, std::unique_ptr<MemRefRegion>, 4> regions;

  // Walk this 'affine.for' operation to gather all memory regions.
  auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
    if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst)) {
      // Neither a load nor a store op.
      return WalkResult::advance();
    }

    // Compute the memref region symbolic in any IVs enclosing this block.
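    // The region is computed at the depth of this block's enclosing loops, so
    // the footprint stays parameterized by (symbolic in) those outer IVs
    // rather than being expanded over their ranges.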
    auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
    if (failed(
            region->compute(opInst,
                            /*loopDepth=*/getNestingDepth(&*block.begin())))) {
      return opInst->emitError("error obtaining memory region\n");
    }

    auto [it, inserted] = regions.try_emplace(region->memref);
    if (inserted) {
      it->second = std::move(region);
    } else if (failed(it->second->unionBoundingBox(*region))) {
      return opInst->emitWarning(
          "getMemoryFootprintBytes: unable to perform a union on a memory "
          "region");
    }
    return WalkResult::advance();
  });
  if (result.wasInterrupted())
    return std::nullopt;

  int64_t totalSizeInBytes = 0;
  for (const auto &region : regions) {
    std::optional<int64_t> size = region.second->getRegionSize();
    if (!size.has_value())
      return std::nullopt;
    totalSizeInBytes += *size;
  }
  return totalSizeInBytes;
}

std::optional<int64_t> mlir::affine::getMemoryFootprintBytes(AffineForOp forOp,
                                                             int memorySpace) {
  auto *forInst = forOp.getOperation();
  return ::getMemoryFootprintBytes(
      *forInst->getBlock(), Block::iterator(forInst),
      std::next(Block::iterator(forInst)), memorySpace);
}

/// Returns whether a loop is parallel and contains a reduction loop.
bool mlir::affine::isLoopParallelAndContainsReduction(AffineForOp forOp) {
  SmallVector<LoopReduction> reductions;
  if (!isLoopParallel(forOp, &reductions))
    return false;
  return !reductions.empty();
}

/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
/// at 'forOp'.
void mlir::affine::getSequentialLoops(
    AffineForOp forOp, llvm::SmallDenseSet<Value, 8> *sequentialLoops) {
  forOp->walk([&](Operation *op) {
    if (auto innerFor = dyn_cast<AffineForOp>(op))
      if (!isLoopParallel(innerFor))
        sequentialLoops->insert(innerFor.getInductionVar());
  });
}

IntegerSet mlir::affine::simplifyIntegerSet(IntegerSet set) {
  FlatAffineValueConstraints fac(set);
  if (fac.isEmpty())
    return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
                                   set.getContext());
  fac.removeTrivialRedundancy();

  auto simplifiedSet = fac.getAsIntegerSet(set.getContext());
  assert(simplifiedSet && "guaranteed to succeed while roundtripping");
  return simplifiedSet;
}

static void unpackOptionalValues(ArrayRef<std::optional<Value>> source,
                                 SmallVector<Value> &target) {
  target =
      llvm::to_vector<4>(llvm::map_range(source, [](std::optional<Value> val) {
        return val.has_value() ? *val : Value();
      }));
}

/// Bound an identifier `pos` in a given FlatAffineValueConstraints with
/// constraints drawn from an affine map. Before adding the constraint, the
/// dimensions/symbols of the affine map are aligned with `constraints`.
/// `operands` are the SSA Value operands used with the affine map.
/// Note: This function adds a new symbol column to the `constraints` for each
/// dimension/symbol that exists in the affine map but not in `constraints`.
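/// For example, if `constraints` currently has dims (%i) and symbols (%N), and
/// `map` also uses an operand %M unknown to `constraints`, then %M is appended
/// as a new symbol column before the bound on `pos` is added. (Illustrative
/// operand names only.)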
static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints,
                                      BoundType type, unsigned pos,
                                      AffineMap map, ValueRange operands) {
  SmallVector<Value> dims, syms, newSyms;
  unpackOptionalValues(constraints.getMaybeValues(VarKind::SetDim), dims);
  unpackOptionalValues(constraints.getMaybeValues(VarKind::Symbol), syms);

  AffineMap alignedMap =
      alignAffineMapWithValues(map, operands, dims, syms, &newSyms);
  for (unsigned i = syms.size(); i < newSyms.size(); ++i)
    constraints.appendSymbolVar(newSyms[i]);
  return constraints.addBound(type, pos, alignedMap);
}

/// Add `val` to each result of `map`.
static AffineMap addConstToResults(AffineMap map, int64_t val) {
  SmallVector<AffineExpr> newResults;
  for (AffineExpr r : map.getResults())
    newResults.push_back(r + val);
  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
                        map.getContext());
}

// Attempt to simplify the given min/max operation by proving that its value is
// bounded by the same lower and upper bound.
//
// Bounds are computed by FlatAffineValueConstraints. Invariants required for
// finding/proving bounds should be supplied via `constraints`.
//
// 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`).
// 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in
//    case of `!isMin`) and bind it to `opBound`. SSA values that are used in
//    `op` but are not part of `constraints`, are added as extra symbols.
// 3. For each result of `op`: Add result as a dimension `r_i`. Prove that:
//    * If `isMin`: r_i >= opBound
//    * If `isMax`: r_i <= opBound
//    If this is the case, ub(op) == lb(op).
// 4. Replace `op` with `opBound`.
//
// In summary, the following constraints are added throughout this function.
// Note: `invar` are dimensions added by the caller to express the invariants.
// (Showing only the case where `isMin`.)
//
//   invar |    op | opBound | r_i | extra syms... | const | eq/ineq
//  -------+-------+---------+-----+---------------+-------+---------
//  (various eq./ineq. constraining `invar`, added by the caller)
//     ... |     0 |       0 |   0 |             0 |   ... | ...
//  -------+-------+---------+-----+---------------+-------+---------
//  (various ineq. constraining `op` in terms of `op` operands (`invar` and
//   extra `op` operands "extra syms" that are not in `invar`)).
//     ... |    -1 |       0 |   0 |           ... |   ... | >= 0
//  -------+-------+---------+-----+---------------+-------+---------
//  (set `opBound` to `op` upper bound in terms of `invar` and "extra syms")
//     ... |     0 |      -1 |   0 |           ... |   ... | = 0
//  -------+-------+---------+-----+---------------+-------+---------
//  (for each `op` map result r_i: set r_i to corresponding map result,
//   prove that r_i >= minOpUb via contradiction)
//     ... |     0 |       0 |  -1 |           ... |   ... | = 0
//       0 |     0 |       1 |  -1 |             0 |    -1 | >= 0
//
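// For intuition, a hedged example (illustrative IR, not drawn from a test):
// if `constraints` encodes 0 <= %i <= 8, then in
//   %m = affine.min affine_map<(d0) -> (16 - d0, 8)>(%i)
// both results are >= 8 over the whole domain while the op itself is <= 8, so
// the proof scheme above would justify replacing %m with the constant 8
// (provided a single-result bound is found in step 2).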
FailureOr<AffineValueMap> mlir::affine::simplifyConstrainedMinMaxOp(
    Operation *op, FlatAffineValueConstraints constraints) {
  bool isMin = isa<AffineMinOp>(op);
  assert((isMin || isa<AffineMaxOp>(op)) && "expect AffineMin/MaxOp");
  MLIRContext *ctx = op->getContext();
  Builder builder(ctx);
  AffineMap map =
      isMin ? cast<AffineMinOp>(op).getMap() : cast<AffineMaxOp>(op).getMap();
  ValueRange operands = op->getOperands();
  unsigned numResults = map.getNumResults();

  // Add a few extra dimensions.
  unsigned dimOp = constraints.appendDimVar();      // `op`
  unsigned dimOpBound = constraints.appendDimVar(); // `op` lower/upper bound
  unsigned resultDimStart = constraints.appendDimVar(/*num=*/numResults);

  // Add an inequality for each result expr_i of map:
  // isMin: op <= expr_i, !isMin: op >= expr_i
  auto boundType = isMin ? BoundType::UB : BoundType::LB;
  // Upper bounds are exclusive, so add 1. (`affine.min` ops are inclusive.)
  AffineMap mapLbUb = isMin ? addConstToResults(map, 1) : map;
  if (failed(
          alignAndAddBound(constraints, boundType, dimOp, mapLbUb, operands)))
    return failure();

  // Try to compute a lower/upper bound for op, expressed in terms of the other
  // `dims` and extra symbols.
  SmallVector<AffineMap> opLb(1), opUb(1);
  constraints.getSliceBounds(dimOp, 1, ctx, &opLb, &opUb);
  AffineMap sliceBound = isMin ? opUb[0] : opLb[0];
  // TODO: `getSliceBounds` may return multiple bounds at the moment. This is
  // a TODO of `getSliceBounds` and not handled here.
  if (!sliceBound || sliceBound.getNumResults() != 1)
    return failure(); // No or multiple bounds found.
  // Recover the inclusive UB in the case of an `affine.min`.
  AffineMap boundMap = isMin ? addConstToResults(sliceBound, -1) : sliceBound;

  // Add an equality: Set dimOpBound to computed bound.
  // Add back dimension for op. (Was removed by `getSliceBounds`.)
  AffineMap alignedBoundMap = boundMap.shiftDims(/*shift=*/1, /*offset=*/dimOp);
  if (failed(constraints.addBound(BoundType::EQ, dimOpBound, alignedBoundMap)))
    return failure();

  // If the constraint system is empty, there is an inconsistency. (E.g., this
  // can happen if loop lb > ub.)
  if (constraints.isEmpty())
    return failure();

  // In the case of `isMin` (`!isMin` is inversed):
  // Prove that each result of `map` has a lower bound that is equal to (or
  // greater than) the upper bound of `op` (`dimOpBound`). In that case, `op`
  // can be replaced with the bound. I.e., prove that for each result
  // expr_i (represented by dimension r_i):
  //
  // r_i >= opBound
  //
  // To prove this inequality, add its negation to the constraint set and prove
  // that the constraint set is empty.
  for (unsigned i = resultDimStart; i < resultDimStart + numResults; ++i) {
    FlatAffineValueConstraints newConstr(constraints);

    // Add an equality: r_i = expr_i
    // Note: These equalities could have been added earlier and used to express
    // minOp <= expr_i. However, then we run the risk that `getSliceBounds`
    // computes minOpUb in terms of r_i dims, which is not desired.
    if (failed(alignAndAddBound(newConstr, BoundType::EQ, i,
                                map.getSubMap({i - resultDimStart}),
                                operands)))
      return failure();

    // If `isMin`: Add inequality: r_i < opBound
    //             equiv.: opBound - r_i - 1 >= 0
    // If `!isMin`: Add inequality: r_i > opBound
    //              equiv.: -opBound + r_i - 1 >= 0
    SmallVector<int64_t> ineq(newConstr.getNumCols(), 0);
    ineq[dimOpBound] = isMin ? 1 : -1;
    ineq[i] = isMin ? -1 : 1;
    ineq[newConstr.getNumCols() - 1] = -1;
    newConstr.addInequality(ineq);
    if (!newConstr.isEmpty())
      return failure();
  }

  // Lower and upper bound of `op` are equal. Replace `minOp` with its bound.
  AffineMap newMap = alignedBoundMap;
  SmallVector<Value> newOperands;
  unpackOptionalValues(constraints.getMaybeValues(), newOperands);
  // If dims/symbols have known constant values, use those in order to simplify
  // the affine map further.
  for (int64_t i = 0, e = constraints.getNumDimAndSymbolVars(); i < e; ++i) {
    // Skip unused operands and operands that are already constants.
    if (!newOperands[i] || getConstantIntValue(newOperands[i]))
      continue;
    if (auto bound = constraints.getConstantBound64(BoundType::EQ, i)) {
      AffineExpr expr =
          i < newMap.getNumDims()
              ? builder.getAffineDimExpr(i)
              : builder.getAffineSymbolExpr(i - newMap.getNumDims());
      newMap = newMap.replace(expr, builder.getAffineConstantExpr(*bound),
                              newMap.getNumDims(), newMap.getNumSymbols());
    }
  }
  affine::canonicalizeMapAndOperands(&newMap, &newOperands);
  return AffineValueMap(newMap, newOperands);
}