//===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous analysis routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "analysis-utils"

using namespace mlir;
using namespace affine;
using namespace presburger;

using llvm::SmallDenseMap;

using Node = MemRefDependenceGraph::Node;

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and records whether a region-holding op other than an
// affine.for or affine.if op was encountered in the loop nest.
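//
// For example, walking a hypothetical nest such as:
//
//   affine.for %i = 0 to 16 {
//     %v = affine.load %A[%i] : memref<16xf32>
//     affine.store %v, %B[%i] : memref<16xf32>
//   }
//
// collects the affine.for op in 'forOps', the load in 'loadOpInsts', and the
// store in 'storeOpInsts', leaving 'hasNonAffineRegionOp' false.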
void LoopNestStateCollector::collect(Operation *opToWalk) {
  opToWalk->walk([&](Operation *op) {
    if (auto forOp = dyn_cast<AffineForOp>(op))
      forOps.push_back(forOp);
    else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op))
      hasNonAffineRegionOp = true;
    else if (isa<AffineReadOpInterface>(op))
      loadOpInsts.push_back(op);
    else if (isa<AffineWriteOpInterface>(op))
      storeOpInsts.push_back(op);
  });
}

// Returns the load op count for 'memref'.
unsigned Node::getLoadOpCount(Value memref) const {
  unsigned loadOpCount = 0;
  for (Operation *loadOp : loads) {
    if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
      ++loadOpCount;
  }
  return loadOpCount;
}

// Returns the store op count for 'memref'.
unsigned Node::getStoreOpCount(Value memref) const {
  unsigned storeOpCount = 0;
  for (Operation *storeOp : stores) {
    if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
      ++storeOpCount;
  }
  return storeOpCount;
}

// Collects in 'storeOps' all store ops from this node which access 'memref'.
void Node::getStoreOpsForMemref(Value memref,
                                SmallVectorImpl<Operation *> *storeOps) const {
  for (Operation *storeOp : stores) {
    if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
      storeOps->push_back(storeOp);
  }
}

// Collects in 'loadOps' all load ops from this node which access 'memref'.
void Node::getLoadOpsForMemref(Value memref,
                               SmallVectorImpl<Operation *> *loadOps) const {
  for (Operation *loadOp : loads) {
    if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
      loadOps->push_back(loadOp);
  }
}

// Collects in 'loadAndStoreMemrefSet' all memrefs for which this node
// has at least one load and at least one store operation.
void Node::getLoadAndStoreMemrefSet(
    DenseSet<Value> *loadAndStoreMemrefSet) const {
  llvm::SmallDenseSet<Value, 2> loadMemrefs;
  for (Operation *loadOp : loads) {
    loadMemrefs.insert(cast<AffineReadOpInterface>(loadOp).getMemRef());
  }
  for (Operation *storeOp : stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
    if (loadMemrefs.count(memref) > 0)
      loadAndStoreMemrefSet->insert(memref);
  }
}

// Initializes the data dependence graph by walking operations in `block`.
// Assigns each node in the graph a node id based on program order in `block`.
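// For example, for a block with two top-level affine.for nests where the
// first stores to a memref %A and the second loads from %A (a hypothetical
// program), init() creates one node per nest plus a dependence edge between
// them keyed on %A.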
bool MemRefDependenceGraph::init() {
  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
  // Map from a memref to the set of ids of the nodes that have ops accessing
  // the memref.
  DenseMap<Value, SetVector<unsigned>> memrefAccesses;

  DenseMap<Operation *, unsigned> forToNodeMap;
  for (Operation &op : block) {
    if (isa<AffineForOp>(op)) {
      // Create graph node 'id' to represent top-level 'forOp' and record
      // all load and store accesses it contains.
      LoopNestStateCollector collector;
      collector.collect(&op);
      // Return false if a region holding op other than 'affine.for' and
      // 'affine.if' was found (not currently supported).
      if (collector.hasNonAffineRegionOp)
        return false;
      Node node(nextNodeId++, &op);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      forToNodeMap[&op] = node.id;
      nodes.insert({node.id, node});
    } else if (isa<AffineReadOpInterface>(op)) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &op);
      node.loads.push_back(&op);
      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (isa<AffineWriteOpInterface>(op)) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &op);
      node.stores.push_back(&op);
      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (op.getNumResults() > 0 && !op.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    } else if (!isMemoryEffectFree(&op) &&
               (op.getNumRegions() == 0 || isa<RegionBranchOpInterface>(op))) {
      // Create graph node for top-level op unless it is known to be
      // memory-effect free. This covers all unknown/unregistered ops,
      // non-affine ops with memory effects, and region-holding ops with a
      // well-defined control flow. During the fusion validity checks, we look
      // for non-affine ops on the path from source to destination, at which
      // point we check which memrefs if any are used in the region.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    } else if (op.getNumRegions() != 0) {
      // Return false if non-handled/unknown region-holding ops are found. We
      // won't know what such ops do or what their regions mean; e.g., they may
      // not be imperative ops.
      LLVM_DEBUG(llvm::dbgs()
                 << "MDG init failed; unknown region-holding op found!\n");
      return false;
    }
  }

  for (auto &idAndNode : nodes) {
    LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
                            << *(idAndNode.second.op) << "\n");
    (void)idAndNode;
  }

  // Add dependence edges between nodes which produce SSA values and their
  // users. Load ops define SSA results, so they count as producers too.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    // Stores don't define SSA values, skip them.
    if (!node.stores.empty())
      continue;
    Operation *opInst = node.op;
    for (Value value : opInst->getResults()) {
      for (Operation *user : value.getUsers()) {
        // Ignore users outside of the block.
        if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() !=
            &block)
          continue;
        SmallVector<AffineForOp, 4> loops;
        getAffineForIVs(*user, &loops);
        // Find the surrounding affine.for nested immediately within the
        // block.
        auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
          return loop->getBlock() == &block;
        });
        if (it == loops.end())
          continue;
        assert(forToNodeMap.count(*it) > 0 && "missing mapping");
        unsigned userLoopNestId = forToNodeMap[*it];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

// Returns the graph node for 'id'.
Node *MemRefDependenceGraph::getNode(unsigned id) {
  auto it = nodes.find(id);
  assert(it != nodes.end());
  return &it->second;
}

// Returns the graph node for 'forOp'.
Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) {
  for (auto &idAndNode : nodes)
    if (idAndNode.second.op == forOp)
      return &idAndNode.second;
  return nullptr;
}

// Adds a node with 'op' to the graph and returns its unique identifier.
unsigned MemRefDependenceGraph::addNode(Operation *op) {
  Node node(nextNodeId++, op);
  nodes.insert({node.id, node});
  return node.id;
}

// Removes node 'id' (and its associated edges) from the graph.
void MemRefDependenceGraph::removeNode(unsigned id) {
  // Remove each edge in 'inEdges[id]'.
  if (inEdges.count(id) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[id];
    for (auto &inEdge : oldInEdges) {
      removeEdge(inEdge.id, id, inEdge.value);
    }
  }
  // Remove each edge in 'outEdges[id]'.
  if (outEdges.count(id) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[id];
    for (auto &outEdge : oldOutEdges) {
      removeEdge(id, outEdge.id, outEdge.value);
    }
  }
  // Erase remaining node state.
  inEdges.erase(id);
  outEdges.erase(id);
  nodes.erase(id);
}

// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
  Node *node = getNode(id);
  for (auto *storeOpInst : node->stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
    auto *op = memref.getDefiningOp();
    // Return true if 'memref' is a block argument.
    if (!op)
      return true;
    // Return true if any use of 'memref' does not dereference it in an affine
    // way.
    for (auto *user : memref.getUsers())
      if (!isa<AffineMapAccessInterface>(*user))
        return true;
  }
  return false;
}

// Returns true iff there is an edge from node 'srcId' to node 'dstId' for
// 'value' if non-null, or for any value otherwise.
bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
                                    Value value) {
  if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
    return false;
  }
  bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
    return edge.id == dstId && (!value || edge.value == value);
  });
  bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
    return edge.id == srcId && (!value || edge.value == value);
  });
  return hasOutEdge && hasInEdge;
}

// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
void MemRefDependenceGraph::addEdge(unsigned srcId, unsigned dstId,
                                    Value value) {
  if (!hasEdge(srcId, dstId, value)) {
    outEdges[srcId].push_back({dstId, value});
    inEdges[dstId].push_back({srcId, value});
    if (isa<MemRefType>(value.getType()))
      memrefEdgeCount[value]++;
  }
}

// Removes an edge from node 'srcId' to node 'dstId' for 'value'.
void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
                                       Value value) {
  assert(inEdges.count(dstId) > 0);
  assert(outEdges.count(srcId) > 0);
  if (isa<MemRefType>(value.getType())) {
    assert(memrefEdgeCount.count(value) > 0);
    memrefEdgeCount[value]--;
  }
  // Remove 'srcId' from 'inEdges[dstId]'.
  for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
    if ((*it).id == srcId && (*it).value == value) {
      inEdges[dstId].erase(it);
      break;
    }
  }
  // Remove 'dstId' from 'outEdges[srcId]'.
  for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
    if ((*it).id == dstId && (*it).value == value) {
      outEdges[srcId].erase(it);
      break;
    }
  }
}

// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connect are expected to be from the same block.
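// For example, with edges 0 -> 1 and 1 -> 2 in the graph,
// hasDependencePath(0, 2) returns true via the transitive path through
// node 1.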
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
  // Worklist state is: <node-id, next-output-edge-index-to-visit>
  SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
  worklist.push_back({srcId, 0});
  Operation *dstOp = getNode(dstId)->op;
  // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
  while (!worklist.empty()) {
    auto &idAndIndex = worklist.back();
    // Return true if we have reached 'dstId'.
    if (idAndIndex.first == dstId)
      return true;
    // Pop and continue if node has no out edges, or if all out edges have
    // already been visited.
    if (outEdges.count(idAndIndex.first) == 0 ||
        idAndIndex.second == outEdges[idAndIndex.first].size()) {
      worklist.pop_back();
      continue;
    }
    // Get graph edge to traverse.
    Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
    // Increment next output edge index for 'idAndIndex'.
    ++idAndIndex.second;
    // Add node at 'edge.id' to the worklist. We don't need to consider
    // nodes that are "after" dstId in the containing block; one can't have a
    // path to `dstId` from any of those nodes.
    bool afterDst = dstOp->isBeforeInBlock(getNode(edge.id)->op);
    if (!afterDst && edge.id != idAndIndex.first)
      worklist.push_back({edge.id, 0});
  }
  return false;
}

// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
                                                          Value memref) {
  unsigned inEdgeCount = 0;
  if (inEdges.count(id) > 0)
    for (auto &inEdge : inEdges[id])
      if (inEdge.value == memref) {
        Node *srcNode = getNode(inEdge.id);
        // Only count in edges from 'srcNode' if 'srcNode' stores to 'memref'.
        if (srcNode->getStoreOpCount(memref) > 0)
          ++inEdgeCount;
      }
  return inEdgeCount;
}

// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
  unsigned outEdgeCount = 0;
  if (outEdges.count(id) > 0)
    for (auto &outEdge : outEdges[id])
      if (!memref || outEdge.value == memref)
        ++outEdgeCount;
  return outEdgeCount;
}

/// Return all nodes which define SSA values used in node 'id'.
void MemRefDependenceGraph::gatherDefiningNodes(
    unsigned id, DenseSet<unsigned> &definingNodes) {
  for (MemRefDependenceGraph::Edge edge : inEdges[id])
    // By definition of edge, if the edge value is a non-memref value,
    // then the dependence is between a graph node which defines an SSA value
    // and another graph node which uses the SSA value.
    if (!isa<MemRefType>(edge.value.getType()))
      definingNodes.insert(edge.id);
}

// Computes and returns an insertion point operation, before which the
// fused <srcId, dstId> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
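// For example, if the block contains, in order: the src nest, an op 'a' on
// which dst depends, an op 'b' which depends on src, and the dst nest, then
// 'lastDstDepPos' (the position of 'a') precedes 'firstSrcDepPos' (the
// position of 'b'), and the fused nest can be inserted just before 'b'.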
Operation *
MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
                                                      unsigned dstId) {
  if (outEdges.count(srcId) == 0)
    return getNode(dstId)->op;

  // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
  DenseSet<unsigned> definingNodes;
  gatherDefiningNodes(dstId, definingNodes);
  if (llvm::any_of(definingNodes,
                   [&](unsigned id) { return hasDependencePath(srcId, id); })) {
    LLVM_DEBUG(llvm::dbgs()
               << "Can't fuse: a defining op with a user in the dst "
                  "loop has dependence from the src loop\n");
    return nullptr;
  }

  // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
  SmallPtrSet<Operation *, 2> srcDepInsts;
  for (auto &outEdge : outEdges[srcId])
    if (outEdge.id != dstId)
      srcDepInsts.insert(getNode(outEdge.id)->op);

  // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
  SmallPtrSet<Operation *, 2> dstDepInsts;
  for (auto &inEdge : inEdges[dstId])
    if (inEdge.id != srcId)
      dstDepInsts.insert(getNode(inEdge.id)->op);

  Operation *srcNodeInst = getNode(srcId)->op;
  Operation *dstNodeInst = getNode(dstId)->op;

  // Computing insertion point:
  // *) Walk all operation positions in Block operation list in the
  //    range (src, dst). For each operation 'op' visited in this search:
  //   *) Store in 'firstSrcDepPos' the first position where 'op' has a
  //      dependence edge from 'srcNode'.
  //   *) Store in 'lastDstDepPos' the last position where 'op' has a
  //      dependence edge to 'dstNode'.
  // *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
  //    operation insertion point (or return null pointer if no such
  //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
  SmallVector<Operation *, 2> depInsts;
  std::optional<unsigned> firstSrcDepPos;
  std::optional<unsigned> lastDstDepPos;
  unsigned pos = 0;
  for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
       it != Block::iterator(dstNodeInst); ++it) {
    Operation *op = &(*it);
    if (srcDepInsts.count(op) > 0 && firstSrcDepPos == std::nullopt)
      firstSrcDepPos = pos;
    if (dstDepInsts.count(op) > 0)
      lastDstDepPos = pos;
    depInsts.push_back(op);
    ++pos;
  }

  if (firstSrcDepPos.has_value()) {
    if (lastDstDepPos.has_value()) {
      if (*firstSrcDepPos <= *lastDstDepPos) {
        // No valid insertion point exists which preserves dependences.
        return nullptr;
      }
    }
    // Return the insertion point at 'firstSrcDepPos'.
    return depInsts[*firstSrcDepPos];
  }
  // No dependence targets in range (or only dst deps in range), return
  // 'dstNodeInst' insertion point.
  return dstNodeInst;
}

// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
// taking into account that:
//   *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
//   *) memrefs in 'privateMemRefs' have been replaced in the node at 'dstId'
//      by a private memref.
void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
                                        const DenseSet<Value> &privateMemRefs,
                                        bool removeSrcId) {
  // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
  if (inEdges.count(srcId) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
    for (auto &inEdge : oldInEdges) {
      // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
      if (privateMemRefs.count(inEdge.value) == 0)
        addEdge(inEdge.id, dstId, inEdge.value);
    }
  }
  // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
  // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
  if (outEdges.count(srcId) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
    for (auto &outEdge : oldOutEdges) {
      // Remove any out edges from 'srcId' to 'dstId' across memrefs.
      if (outEdge.id == dstId)
        removeEdge(srcId, outEdge.id, outEdge.value);
      else if (removeSrcId) {
        addEdge(dstId, outEdge.id, outEdge.value);
        removeEdge(srcId, outEdge.id, outEdge.value);
      }
    }
  }
  // Remove any edges in 'inEdges[dstId]' on memrefs that are being replaced
  // by private memrefs. These edges could come from nodes other than 'srcId'
  // which were removed in the previous step.
  if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
    SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
    for (auto &inEdge : oldInEdges)
      if (privateMemRefs.count(inEdge.value) > 0)
        removeEdge(inEdge.id, dstId, inEdge.value);
  }
}

// Updates edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
// of sibling node 'sibId' into node 'dstId'.
void MemRefDependenceGraph::updateEdges(unsigned sibId, unsigned dstId) {
  // For each edge in 'inEdges[sibId]':
  // *) Add new edge from source node 'inEdge.id' to 'dstId'.
  // *) Remove edge from source node 'inEdge.id' to 'sibId'.
  if (inEdges.count(sibId) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
    for (auto &inEdge : oldInEdges) {
      addEdge(inEdge.id, dstId, inEdge.value);
      removeEdge(inEdge.id, sibId, inEdge.value);
    }
  }

  // For each edge in 'outEdges[sibId]':
  // *) Add new edge from 'dstId' to 'outEdge.id'.
  // *) Remove edge from 'sibId' to 'outEdge.id'.
  if (outEdges.count(sibId) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
    for (auto &outEdge : oldOutEdges) {
      addEdge(dstId, outEdge.id, outEdge.value);
      removeEdge(sibId, outEdge.id, outEdge.value);
    }
  }
}

// Adds ops in 'loads' and 'stores' to node at 'id'.
void MemRefDependenceGraph::addToNode(
    unsigned id, const SmallVectorImpl<Operation *> &loads,
    const SmallVectorImpl<Operation *> &stores) {
  Node *node = getNode(id);
  llvm::append_range(node->loads, loads);
  llvm::append_range(node->stores, stores);
}

void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) {
  Node *node = getNode(id);
  node->loads.clear();
  node->stores.clear();
}

// Calls 'callback' for each input edge incident to node 'id' which carries a
// memref dependence.
void MemRefDependenceGraph::forEachMemRefInputEdge(
    unsigned id, const std::function<void(Edge)> &callback) {
  if (inEdges.count(id) > 0)
    forEachMemRefEdge(inEdges[id], callback);
}

// Calls 'callback' for each output edge from node 'id' which carries a
// memref dependence.
void MemRefDependenceGraph::forEachMemRefOutputEdge(
    unsigned id, const std::function<void(Edge)> &callback) {
  if (outEdges.count(id) > 0)
    forEachMemRefEdge(outEdges[id], callback);
}

// Calls 'callback' for each edge in 'edges' which carries a memref
// dependence.
void MemRefDependenceGraph::forEachMemRefEdge(
    ArrayRef<Edge> edges, const std::function<void(Edge)> &callback) {
  for (const auto &edge : edges) {
    // Skip if 'edge' is not a memref dependence edge.
    if (!isa<MemRefType>(edge.value.getType()))
      continue;
    assert(nodes.count(edge.id) > 0);
    // Skip if 'edge.id' is not a loop nest.
    if (!isa<AffineForOp>(getNode(edge.id)->op))
      continue;
    // Visit current input edge 'edge'.
    callback(edge);
  }
}

void MemRefDependenceGraph::print(raw_ostream &os) const {
  os << "\nMemRefDependenceGraph\n";
  os << "\nNodes:\n";
  for (const auto &idAndNode : nodes) {
    os << "Node: " << idAndNode.first << "\n";
    auto it = inEdges.find(idAndNode.first);
    if (it != inEdges.end()) {
      for (const auto &e : it->second)
        os << "  InEdge: " << e.id << " " << e.value << "\n";
    }
    it = outEdges.find(idAndNode.first);
    if (it != outEdges.end()) {
      for (const auto &e : it->second)
        os << "  OutEdge: " << e.id << " " << e.value << "\n";
    }
  }
}

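/// Collects in 'loops' the affine.for ops surrounding 'op', ordered from
/// outermost to innermost, skipping over affine.if ops. For example, for a
/// hypothetical op nested as:
///
///   affine.for %i = 0 to 10 {
///     affine.for %j = 0 to 10 {
///       <op>
///     }
///   }
///
/// 'loops' is filled with the %i loop followed by the %j loop.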
void mlir::affine::getAffineForIVs(Operation &op,
                                   SmallVectorImpl<AffineForOp> *loops) {
  auto *currOp = op.getParentOp();
  // Traverse up the hierarchy collecting all 'affine.for' operations while
  // skipping over 'affine.if' operations.
  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
    if (auto currAffineForOp = dyn_cast<AffineForOp>(currOp))
      loops->push_back(currAffineForOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(loops->begin(), loops->end());
}

void mlir::affine::getEnclosingAffineOps(Operation &op,
                                         SmallVectorImpl<Operation *> *ops) {
  ops->clear();
  Operation *currOp = op.getParentOp();

  // Traverse up the hierarchy collecting all `affine.for`, `affine.if`, and
  // `affine.parallel` operations.
  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
    if (isa<AffineIfOp, AffineForOp, AffineParallelOp>(currOp))
      ops->push_back(currOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(ops->begin(), ops->end());
}

// Populates 'cst' with FlatAffineValueConstraints which represent the original
// domain of the loop bounds that define 'ivs'.
LogicalResult ComputationSliceState::getSourceAsConstraints(
    FlatAffineValueConstraints &cst) const {
  assert(!ivs.empty() && "Cannot have a slice without its IVs");
  cst = FlatAffineValueConstraints(/*numDims=*/ivs.size(), /*numSymbols=*/0,
                                   /*numLocals=*/0, ivs);
  for (Value iv : ivs) {
    AffineForOp loop = getForInductionVarOwner(iv);
    assert(loop && "Expected affine for");
    if (failed(cst.addAffineForOpDomain(loop)))
      return failure();
  }
  return success();
}

// Populates 'cst' with FlatAffineValueConstraints which represent the slice
// bounds.
LogicalResult
ComputationSliceState::getAsConstraints(FlatAffineValueConstraints *cst) const {
  assert(!lbOperands.empty());
  // Adds src 'ivs' as dimension variables in 'cst'.
  unsigned numDims = ivs.size();
  // Adds operands (dst ivs and symbols) as symbols in 'cst'.
  unsigned numSymbols = lbOperands[0].size();

  // 'values' holds the src 'ivs' followed by the operands.
  SmallVector<Value, 4> values(ivs);
  values.append(lbOperands[0].begin(), lbOperands[0].end());
  *cst = FlatAffineValueConstraints(numDims, numSymbols, 0, values);

  // Add loop bound constraints for values which are loop IVs of the
  // destination of fusion, and equality constraints for symbols which are
  // constants.
  for (unsigned i = numDims, end = values.size(); i < end; ++i) {
    Value value = values[i];
    assert(cst->containsVar(value) && "value expected to be present");
    if (isValidSymbol(value)) {
      // Check if the symbol is a constant.
      if (std::optional<int64_t> cOp = getConstantIntValue(value))
        cst->addBound(BoundType::EQ, value, cOp.value());
    } else if (auto loop = getForInductionVarOwner(value)) {
      if (failed(cst->addAffineForOpDomain(loop)))
        return failure();
    }
  }

  // Add slice bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]'.
  LogicalResult ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]);
  assert(succeeded(ret) &&
         "should not fail as we never have semi-affine slice maps");
  (void)ret;
  return success();
}

// Clears state bounds and operand state.
void ComputationSliceState::clearBounds() {
  lbs.clear();
  ubs.clear();
  lbOperands.clear();
  ubOperands.clear();
}

void ComputationSliceState::dump() const {
  llvm::errs() << "\tIVs:\n";
  for (Value iv : ivs)
    llvm::errs() << "\t\t" << iv << "\n";

  llvm::errs() << "\tLBs:\n";
  for (auto en : llvm::enumerate(lbs)) {
    llvm::errs() << "\t\t" << en.value() << "\n";
    llvm::errs() << "\t\tOperands:\n";
    for (Value lbOp : lbOperands[en.index()])
      llvm::errs() << "\t\t\t" << lbOp << "\n";
  }

  llvm::errs() << "\tUBs:\n";
  for (auto en : llvm::enumerate(ubs)) {
    llvm::errs() << "\t\t" << en.value() << "\n";
    llvm::errs() << "\t\tOperands:\n";
    for (Value ubOp : ubOperands[en.index()])
      llvm::errs() << "\t\t\t" << ubOp << "\n";
  }
}

/// Fast check to determine if the computation slice is maximal. Returns true
/// if each slice dimension maps to an existing dst dimension and both the src
/// and the dst loops for those dimensions have the same bounds. Returns false
/// if it can be determined that the src and the dst loops don't have the same
/// bounds. Returns std::nullopt if none of the above can be proven.
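///
/// For example, a slice dimension with lbMap = (d0) -> (d0) and
/// ubMap = (d0) -> (d0 + 1) is a single-iteration window anchored at a dst
/// loop IV; if that dst loop and the corresponding src loop both have, say,
/// constant bounds 0 to 32 with step 1, this check returns true for that
/// dimension.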
std::optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
  assert(lbs.size() == ubs.size() && !lbs.empty() && !ivs.empty() &&
         "Unexpected number of lbs, ubs and ivs in slice");

  for (unsigned i = 0, end = lbs.size(); i < end; ++i) {
    AffineMap lbMap = lbs[i];
    AffineMap ubMap = ubs[i];

    // Check if this slice is just an equality along this dimension.
    if (!lbMap || !ubMap || lbMap.getNumResults() != 1 ||
        ubMap.getNumResults() != 1 ||
        lbMap.getResult(0) + 1 != ubMap.getResult(0) ||
        // The condition above will be true for maps describing a single
        // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
        // Make sure we skip those cases by checking that the lb result is not
        // just a constant.
        isa<AffineConstantExpr>(lbMap.getResult(0)))
      return std::nullopt;

    // Limited support: we expect the lb result to be just a loop dimension for
    // now.
    AffineDimExpr result = dyn_cast<AffineDimExpr>(lbMap.getResult(0));
    if (!result)
      return std::nullopt;

    // Retrieve dst loop bounds.
    AffineForOp dstLoop =
        getForInductionVarOwner(lbOperands[i][result.getPosition()]);
    if (!dstLoop)
      return std::nullopt;
    AffineMap dstLbMap = dstLoop.getLowerBoundMap();
    AffineMap dstUbMap = dstLoop.getUpperBoundMap();

    // Retrieve src loop bounds.
    AffineForOp srcLoop = getForInductionVarOwner(ivs[i]);
    assert(srcLoop && "Expected affine for");
    AffineMap srcLbMap = srcLoop.getLowerBoundMap();
    AffineMap srcUbMap = srcLoop.getUpperBoundMap();

    // Limited support: we expect simple src and dst loops with a single
    // constant component per bound for now.
    if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 ||
        dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1)
      return std::nullopt;

    AffineExpr srcLbResult = srcLbMap.getResult(0);
    AffineExpr dstLbResult = dstLbMap.getResult(0);
    AffineExpr srcUbResult = srcUbMap.getResult(0);
    AffineExpr dstUbResult = dstUbMap.getResult(0);
    if (!isa<AffineConstantExpr>(srcLbResult) ||
        !isa<AffineConstantExpr>(srcUbResult) ||
        !isa<AffineConstantExpr>(dstLbResult) ||
        !isa<AffineConstantExpr>(dstUbResult))
      return std::nullopt;

    // Check if src and dst loop bounds are the same. If not, we can guarantee
    // that the slice is not maximal.
    if (srcLbResult != dstLbResult || srcUbResult != dstUbResult ||
        srcLoop.getStep() != dstLoop.getStep())
      return false;
  }

  return true;
}

/// Returns true if it is deterministically verified that the original
/// iteration space of the slice is contained within the new iteration space
/// that is created after fusing 'this' slice into its destination.
std::optional<bool> ComputationSliceState::isSliceValid() const {
  // Fast check to determine if the slice is valid. If the following conditions
  // are verified to be true, the slice is declared valid by the fast check:
  // 1. Each slice loop is a single iteration loop bound in terms of a single
  //    destination loop IV.
  // 2. Loop bounds of the destination loop IV (from above) and those of the
  //    source loop IV are exactly the same.
  // If the fast check is inconclusive or false, we proceed with a more
  // expensive analysis.
  // TODO: Store the result of the fast check, as it might be used again in
  // `canRemoveSrcNodeAfterFusion`.
  std::optional<bool> isValidFastCheck = isSliceMaximalFastCheck();
  if (isValidFastCheck && *isValidFastCheck)
    return true;

  // Create constraints for the source loop nest from which the slice is
  // computed.
  FlatAffineValueConstraints srcConstraints;
  // TODO: Store the source's domain to avoid computation at each depth.
  if (failed(getSourceAsConstraints(srcConstraints))) {
    LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n");
    return std::nullopt;
  }
  // As the set difference utility currently cannot handle symbols in its
  // operands, validity of the slice cannot be determined.
  if (srcConstraints.getNumSymbolVars() > 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n");
    return std::nullopt;
  }
  // TODO: Handle local vars in the source domains while using the 'projectOut'
  // utility below. Currently, aligning is not done assuming that there will be
  // no local vars in the source domain.
  if (srcConstraints.getNumLocalVars() != 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n");
    return std::nullopt;
  }

  // Create constraints for the slice loop nest that would be created if the
  // fusion succeeds.
  FlatAffineValueConstraints sliceConstraints;
  if (failed(getAsConstraints(&sliceConstraints))) {
    LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n");
    return std::nullopt;
  }

  // Project out every dimension other than the 'ivs' to express the slice's
  // domain completely in terms of the source's IVs.
  sliceConstraints.projectOut(ivs.size(),
                              sliceConstraints.getNumVars() - ivs.size());

  LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n");
  LLVM_DEBUG(srcConstraints.dump());
  LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds "
                             "(expressed in terms of its source's IVs):\n");
  LLVM_DEBUG(sliceConstraints.dump());

  // TODO: Store 'srcSet' to avoid recalculating for each depth.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = sliceSet.subtract(srcSet);

  if (!diffSet.isIntegerEmpty()) {
    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n");
    return false;
  }
  return true;
}

/// Returns true if the computation slice encloses all the iterations of the
/// sliced loop nest. Returns false if it does not. Returns std::nullopt if it
/// cannot determine if the slice is maximal or not.
std::optional<bool> ComputationSliceState::isMaximal() const {
  // Fast check to determine if the computation slice is maximal. If the result
  // is inconclusive, we proceed with a more expensive analysis.
  std::optional<bool> isMaximalFastCheck = isSliceMaximalFastCheck();
  if (isMaximalFastCheck)
    return isMaximalFastCheck;

  // Create constraints for the src loop nest being sliced.
  FlatAffineValueConstraints srcConstraints(/*numDims=*/ivs.size(),
                                            /*numSymbols=*/0,
                                            /*numLocals=*/0, ivs);
  for (Value iv : ivs) {
    AffineForOp loop = getForInductionVarOwner(iv);
    assert(loop && "Expected affine for");
    if (failed(srcConstraints.addAffineForOpDomain(loop)))
      return std::nullopt;
  }

  // Create constraints for the slice using the dst loop nest information. We
  // retrieve existing dst loops from the lbOperands.
  SmallVector<Value> consumerIVs;
  for (Value lbOp : lbOperands[0])
    if (getForInductionVarOwner(lbOp))
      consumerIVs.push_back(lbOp);

  // Add empty IV values for those new loops that are not equalities and,
  // therefore, are not yet materialized in the IR.
  for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i)
    consumerIVs.push_back(Value());

  FlatAffineValueConstraints sliceConstraints(/*numDims=*/consumerIVs.size(),
                                              /*numSymbols=*/0,
                                              /*numLocals=*/0, consumerIVs);

  if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0])))
    return std::nullopt;

  if (srcConstraints.getNumDimVars() != sliceConstraints.getNumDimVars())
    // Constraint dims are different. The integer set difference can't be
    // computed so we don't know if the slice is maximal.
    return std::nullopt;

  // Compute the difference between the src loop nest and the slice integer
  // sets.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = srcSet.subtract(sliceSet);
  return diffSet.isIntegerEmpty();
}

unsigned MemRefRegion::getRank() const {
  return cast<MemRefType>(memref.getType()).getRank();
}

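/// Computes a constant upper bound on the number of elements in this region
/// (the product of the constant bounding extents along each dimension),
/// optionally also returning the bounding shape, the symbolic lower bounds,
/// and their divisors. For example, a hypothetical region
/// {memref = %A, {%i <= d0 <= %i + 7}} has bounding shape [8] and size 8.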
std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
    SmallVectorImpl<int64_t> *shape, std::vector<SmallVector<int64_t, 4>> *lbs,
    SmallVectorImpl<int64_t> *lbDivisors) const {
  auto memRefType = cast<MemRefType>(memref.getType());
  unsigned rank = memRefType.getRank();
  if (shape)
    shape->reserve(rank);

  assert(rank == cst.getNumDimVars() && "inconsistent memref region");

  // Use a copy of the region constraints that has upper/lower bounds for each
  // memref dimension with static size added to guard against potential
  // over-approximation from projection or union bounding box. We may not add
  // this on the region itself since they might just be redundant constraints
  // that will need non-trivial means to eliminate.
  FlatAffineValueConstraints cstWithShapeBounds(cst);
  for (unsigned r = 0; r < rank; r++) {
    cstWithShapeBounds.addBound(BoundType::LB, r, 0);
    int64_t dimSize = memRefType.getDimSize(r);
    if (ShapedType::isDynamic(dimSize))
      continue;
    cstWithShapeBounds.addBound(BoundType::UB, r, dimSize - 1);
  }

  // Find a constant upper bound on the extent of this memref region along each
  // dimension.
  int64_t numElements = 1;
  int64_t diffConstant;
  int64_t lbDivisor;
  for (unsigned d = 0; d < rank; d++) {
    SmallVector<int64_t, 4> lb;
    std::optional<int64_t> diff =
        cstWithShapeBounds.getConstantBoundOnDimSize64(d, &lb, &lbDivisor);
    if (diff.has_value()) {
      diffConstant = *diff;
      assert(diffConstant >= 0 && "Dim size bound can't be negative");
      assert(lbDivisor > 0);
    } else {
      // If no constant bound is found, then it can always be bound by the
      // memref's dim size if the latter has a constant size along this dim.
      auto dimSize = memRefType.getDimSize(d);
      if (dimSize == ShapedType::kDynamic)
        return std::nullopt;
      diffConstant = dimSize;
      // Lower bound becomes 0.
      lb.resize(cstWithShapeBounds.getNumSymbolVars() + 1, 0);
      lbDivisor = 1;
    }
    numElements *= diffConstant;
    if (lbs) {
      lbs->push_back(lb);
      assert(lbDivisors && "both lbs and lbDivisor or none");
      lbDivisors->push_back(lbDivisor);
    }
    if (shape) {
      shape->push_back(diffConstant);
    }
  }
  return numElements;
}

void MemRefRegion::getLowerAndUpperBound(unsigned pos, AffineMap &lbMap,
                                         AffineMap &ubMap) const {
  assert(pos < cst.getNumDimVars() && "invalid position");
  auto memRefType = cast<MemRefType>(memref.getType());
  unsigned rank = memRefType.getRank();

  assert(rank == cst.getNumDimVars() && "inconsistent memref region");

  auto boundPairs = cst.getLowerAndUpperBound(
      pos, /*offset=*/0, /*num=*/rank, cst.getNumDimAndSymbolVars(),
      /*localExprs=*/{}, memRefType.getContext());
  lbMap = boundPairs.first;
  ubMap = boundPairs.second;
  assert(lbMap && "lower bound for a region must exist");
  assert(ubMap && "upper bound for a region must exist");
  assert(lbMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
  assert(ubMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
}

LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
  assert(memref == other.memref);
  return cst.unionBoundingBox(*other.getConstraints());
}

/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parametric in 'loopDepth' loops
/// surrounding 'op' and any additional Function symbols.
//  For example, the memref region for this load operation at loopDepth = 1
//  will be as below:
//
//    affine.for %i = 0 to 32 {
//      affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
//        load %A[%ii]
//      }
//    }
//
// region:  {memref = %A, write = false, {%i <= m0 <= %i + 7} }
// The last field is a 2-d FlatAffineValueConstraints symbolic in %i.
//
// TODO: extend this to any other memref dereferencing ops
// (dma_start, dma_wait).
LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
                                    const ComputationSliceState *sliceState,
                                    bool addMemRefDimBounds) {
  assert((isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) &&
         "affine read/write op expected");

  MemRefAccess access(op);
  memref = access.memref;
  write = access.isStore();

  unsigned rank = access.getRank();

  LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op
                          << "\ndepth: " << loopDepth << "\n";);

  // 0-d memrefs.
  if (rank == 0) {
    SmallVector<Value, 4> ivs;
    getAffineIVs(*op, ivs);
    assert(loopDepth <= ivs.size() && "invalid 'loopDepth'");
    // The first 'loopDepth' IVs are symbols for this region.
    ivs.resize(loopDepth);
    // A 0-d memref has a 0-d region.
    cst = FlatAffineValueConstraints(rank, loopDepth, /*numLocals=*/0, ivs);
    return success();
  }

  // Build the constraints for this region.
  AffineValueMap accessValueMap;
  access.getAccessMap(&accessValueMap);
  AffineMap accessMap = accessValueMap.getAffineMap();

  unsigned numDims = accessMap.getNumDims();
  unsigned numSymbols = accessMap.getNumSymbols();
  unsigned numOperands = accessValueMap.getNumOperands();
  // Merge operands with slice operands.
  SmallVector<Value, 4> operands;
  operands.resize(numOperands);
  for (unsigned i = 0; i < numOperands; ++i)
    operands[i] = accessValueMap.getOperand(i);

  if (sliceState != nullptr) {
    operands.reserve(operands.size() + sliceState->lbOperands[0].size());
    // Append slice operands to 'operands' as symbols.
    for (auto extraOperand : sliceState->lbOperands[0]) {
      if (!llvm::is_contained(operands, extraOperand)) {
        operands.push_back(extraOperand);
        numSymbols++;
      }
    }
  }
  // We'll first associate the dims and symbols of the access map to the dims
  // and symbols resp. of cst. This will change below once cst is
  // fully constructed.
  cst = FlatAffineValueConstraints(numDims, numSymbols, 0, operands);

  // Add inequalities for loop lower/upper bounds, and equalities for constant
  // symbols.
  for (unsigned i = 0; i < numDims + numSymbols; ++i) {
    auto operand = operands[i];
    if (auto affineFor = getForInductionVarOwner(operand)) {
      // Note that cst can now have more dimensions than accessMap if the
      // bounds expressions involve outer loops or other symbols.
      // TODO: rewrite this to use getInstIndexSet; this way
      // conditionals will be handled when the latter supports it.
      if (failed(cst.addAffineForOpDomain(affineFor)))
        return failure();
    } else if (auto parallelOp = getAffineParallelInductionVarOwner(operand)) {
      if (failed(cst.addAffineParallelOpDomain(parallelOp)))
        return failure();
    } else if (isValidSymbol(operand)) {
      // Check if the symbol is a constant.
      Value symbol = operand;
      if (auto constVal = getConstantIntValue(symbol))
        cst.addBound(BoundType::EQ, symbol, constVal.value());
    } else {
      LLVM_DEBUG(llvm::dbgs() << "unknown affine dimensional value\n");
      return failure();
    }
  }

  // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
  if (sliceState != nullptr) {
    // Add dim and symbol slice operands.
    for (auto operand : sliceState->lbOperands[0]) {
      cst.addInductionVarOrTerminalSymbol(operand);
    }
    // Add upper/lower bounds from 'sliceState' to 'cst'.
    LogicalResult ret =
        cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs,
                           sliceState->lbOperands[0]);
    assert(succeeded(ret) &&
           "should not fail as we never have semi-affine slice maps");
    (void)ret;
  }

  // Add access function equalities to connect loop IVs to data dimensions.
  if (failed(cst.composeMap(&accessValueMap))) {
    op->emitError("getMemRefRegion: compose affine map failed");
    LLVM_DEBUG(accessValueMap.getAffineMap().dump());
    return failure();
  }

  // Set all variables appearing after the first 'rank' variables as
  // symbolic variables - so that the ones corresponding to the memref
  // dimensions are the dimensional variables for the memref region.
  cst.setDimSymbolSeparation(cst.getNumDimAndSymbolVars() - rank);

  // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
  // this memref region is symbolic.
  SmallVector<Value, 4> enclosingIVs;
  getAffineIVs(*op, enclosingIVs);
  assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
  enclosingIVs.resize(loopDepth);
  SmallVector<Value, 4> vars;
  cst.getValues(cst.getNumDimVars(), cst.getNumDimAndSymbolVars(), &vars);
  for (Value var : vars) {
    if ((isAffineInductionVar(var)) && !llvm::is_contained(enclosingIVs, var)) {
      cst.projectOut(var);
    }
  }

  // Project out any local variables (these would have been added for any
  // mod/divs).
  cst.projectOut(cst.getNumDimAndSymbolVars(), cst.getNumLocalVars());

  // Constant fold any symbolic variables.
  cst.constantFoldVarRange(/*pos=*/cst.getNumDimVars(),
                           /*num=*/cst.getNumSymbolVars());

  assert(cst.getNumDimVars() == rank && "unexpected MemRefRegion format");

  // Add upper/lower bounds for each memref dimension with static size
  // to guard against potential over-approximation from projection.
  // TODO: Support dynamic memref dimensions.
  if (addMemRefDimBounds) {
    auto memRefType = cast<MemRefType>(memref.getType());
    for (unsigned r = 0; r < rank; r++) {
      cst.addBound(BoundType::LB, /*pos=*/r, /*value=*/0);
      if (memRefType.isDynamicDim(r))
        continue;
      cst.addBound(BoundType::UB, /*pos=*/r, memRefType.getDimSize(r) - 1);
    }
  }
  cst.removeTrivialRedundancy();

  LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
  LLVM_DEBUG(cst.dump());
  return success();
}

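/// Returns the size in bytes of a single element of 'memRefType', rounding
/// the bit width up to a whole byte: e.g., 4 for f32, 16 for vector<4xf32>,
/// and std::nullopt for non-int/float element types.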
std::optional<int64_t>
mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else if (auto vectorType = dyn_cast<VectorType>(elementType)) {
    if (vectorType.getElementType().isIntOrFloat())
      sizeInBits =
          vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
    else
      return std::nullopt;
  } else {
    return std::nullopt;
  }
  return llvm::divideCeil(sizeInBits, 8);
}


// Returns the size of this region in bytes.
std::optional<int64_t> MemRefRegion::getRegionSize() {
  auto memRefType = cast<MemRefType>(memref.getType());

  if (!memRefType.getLayout().isIdentity()) {
    LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
    return std::nullopt;
  }

  // Compute the extents of the buffer.
  std::optional<int64_t> numElements = getConstantBoundingSizeAndShape();
  if (!numElements) {
    LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
    return std::nullopt;
  }
  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType);
  if (!eltSize)
    return std::nullopt;
  return *eltSize * *numElements;
}

/// Returns the size of memref data in bytes if it's statically shaped,
/// std::nullopt otherwise. If the memref's element type is a vector type, the
/// size of the vector is taken into account as well.
//  TODO: improve/complete this when we have target data.
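//
//  For example, memref<4x4xf32> yields 4 * 4 * 4 = 64 bytes, and
//  memref<2xvector<4xf32>> yields 2 * 16 = 32 bytes.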
std::optional<uint64_t>
mlir::affine::getIntOrFloatMemRefSizeInBytes(MemRefType memRefType) {
  if (!memRefType.hasStaticShape())
    return std::nullopt;
  auto elementType = memRefType.getElementType();
  if (!elementType.isIntOrFloat() && !isa<VectorType>(elementType))
    return std::nullopt;

  auto sizeInBytes = getMemRefIntOrFloatEltSizeInBytes(memRefType);
  if (!sizeInBytes)
    return std::nullopt;
  for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
    sizeInBytes = *sizeInBytes * memRefType.getDimSize(i);
  }
  return sizeInBytes;
}

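/// Verifies that the memory accesses of 'loadOrStoreOp' remain within the
/// bounds of the accessed memref, emitting errors if 'emitError' is true.
/// For example, a hypothetical access such as:
///
///   affine.for %i = 0 to 32 {
///     affine.load %A[%i + 8] : memref<32xf32>
///   }
///
/// is flagged as an out-of-upper-bound access along dimension #1, since
/// %i + 8 ranges up to 39 while the memref extent is 32.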
template <typename LoadOrStoreOp>
LogicalResult mlir::affine::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
                                                    bool emitError) {
  static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
                                AffineWriteOpInterface>::value,
                "argument should be either an AffineReadOpInterface or an "
                "AffineWriteOpInterface");

  Operation *op = loadOrStoreOp.getOperation();
  MemRefRegion region(op->getLoc());
  if (failed(region.compute(op, /*loopDepth=*/0, /*sliceState=*/nullptr,
                            /*addMemRefDimBounds=*/false)))
    return success();

  LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
  LLVM_DEBUG(region.getConstraints()->dump());

  bool outOfBounds = false;
  unsigned rank = loadOrStoreOp.getMemRefType().getRank();

  // For each dimension, check for out of bounds.
  for (unsigned r = 0; r < rank; r++) {
    // Intersect memory region with constraints capturing out of bounds (both
    // past the upper bound and below the lower bound), and check if the
    // constraint system is feasible. If it is, there is at least one point
    // out of bounds.
    int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r);
    // TODO: handle dynamic dim sizes.
    if (ShapedType::isDynamic(dimSize))
      continue;

    // Check for overflow: d_i >= memref dim size.
    FlatAffineValueConstraints ucst(*region.getConstraints());
    ucst.addBound(BoundType::LB, r, dimSize);
    bool outOfUpperBound = !ucst.isEmpty();
    if (outOfUpperBound && emitError) {
      loadOrStoreOp.emitOpError()
          << "memref out of upper bound access along dimension #" << (r + 1);
    }

    // Check for a negative index: d_i <= -1.
    FlatAffineValueConstraints lcst(*region.getConstraints());
    lcst.addBound(BoundType::UB, r, -1);
    bool outOfLowerBound = !lcst.isEmpty();
    if (outOfLowerBound && emitError) {
      loadOrStoreOp.emitOpError()
          << "memref out of lower bound access along dimension #" << (r + 1);
    }
    outOfBounds |= outOfUpperBound || outOfLowerBound;
  }
  return failure(outOfBounds);
}
1316 
1317 // Explicitly instantiate the template so that the compiler knows we need them!
1318 template LogicalResult
1319 mlir::affine::boundCheckLoadOrStoreOp(AffineReadOpInterface loadOp,
1320                                       bool emitError);
1321 template LogicalResult
1322 mlir::affine::boundCheckLoadOrStoreOp(AffineWriteOpInterface storeOp,
1323                                       bool emitError);
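
     // A usage sketch (illustrative; 'op' is a hypothetical operation known to
     // be an affine load):
     //
     //   if (auto readOp = dyn_cast<AffineReadOpInterface>(op)) {
     //     // Emits an op error and returns failure() if any access is provably
     //     // out of bounds along some memref dimension.
     //     if (failed(boundCheckLoadOrStoreOp(readOp, /*emitError=*/true)))
     //       return failure();
     //   }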
1324 
1325 // Returns in 'positions' the Block positions of 'op' in each ancestor
1326 // Block, from the Block containing 'op' up to (but excluding) 'limitBlock'.
1327 static void findInstPosition(Operation *op, Block *limitBlock,
1328                              SmallVectorImpl<unsigned> *positions) {
1329   Block *block = op->getBlock();
1330   while (block != limitBlock) {
1331     // FIXME: This algorithm is unnecessarily O(n) and should be improved to not
1332     // rely on linear scans.
1333     int instPosInBlock = std::distance(block->begin(), op->getIterator());
1334     positions->push_back(instPosInBlock);
1335     op = block->getParentOp();
1336     block = op->getBlock();
1337   }
1338   std::reverse(positions->begin(), positions->end());
1339 }
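
     // For instance (illustrative): if 'op' is the third operation of a block
     // whose parent op is the first operation in 'limitBlock', the computed
     // 'positions' is {0, 2} -- outermost block position first.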
1340 
1341 // Returns the Operation in a possibly nested set of Blocks, where the
1342 // position of the operation is represented by 'positions', which has a
1343 // Block position for each level of nesting.
1344 static Operation *getInstAtPosition(ArrayRef<unsigned> positions,
1345                                     unsigned level, Block *block) {
1346   unsigned i = 0;
1347   for (auto &op : *block) {
1348     if (i != positions[level]) {
1349       ++i;
1350       continue;
1351     }
1352     if (level == positions.size() - 1)
1353       return &op;
1354     if (auto childAffineForOp = dyn_cast<AffineForOp>(op))
1355       return getInstAtPosition(positions, level + 1,
1356                                childAffineForOp.getBody());
1357 
1358     for (auto &region : op.getRegions()) {
1359       for (auto &b : region)
1360         if (auto *ret = getInstAtPosition(positions, level + 1, &b))
1361           return ret;
1362     }
1363     return nullptr;
1364   }
1365   return nullptr;
1366 }
1367 
1368 // Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
1369 static LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
1370                                             FlatAffineValueConstraints *cst) {
1371   for (unsigned i = 0, e = cst->getNumDimVars(); i < e; ++i) {
1372     auto value = cst->getValue(i);
1373     if (ivs.count(value) == 0) {
1374       assert(isAffineForInductionVar(value));
1375       auto loop = getForInductionVarOwner(value);
1376       if (failed(cst->addAffineForOpDomain(loop)))
1377         return failure();
1378     }
1379   }
1380   return success();
1381 }
1382 
1383 /// Returns the innermost common loop depth for the set of operations in 'ops'.
1384 // TODO: Move this to LoopUtils.
1385 unsigned mlir::affine::getInnermostCommonLoopDepth(
1386     ArrayRef<Operation *> ops, SmallVectorImpl<AffineForOp> *surroundingLoops) {
1387   unsigned numOps = ops.size();
1388   assert(numOps > 0 && "Expected at least one operation");
1389 
1390   std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
1391   unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
1392   for (unsigned i = 0; i < numOps; ++i) {
1393     getAffineForIVs(*ops[i], &loops[i]);
1394     loopDepthLimit =
1395         std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
1396   }
1397 
1398   unsigned loopDepth = 0;
1399   for (unsigned d = 0; d < loopDepthLimit; ++d) {
1400     unsigned i;
1401     for (i = 1; i < numOps; ++i) {
1402       if (loops[i - 1][d] != loops[i][d])
1403         return loopDepth;
1404     }
1405     if (surroundingLoops)
1406       surroundingLoops->push_back(loops[i - 1][d]);
1407     ++loopDepth;
1408   }
1409   return loopDepth;
1410 }
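
     // A usage sketch (illustrative; 'opA' and 'opB' are hypothetical operations
     // nested under a common loop nest):
     //
     //   SmallVector<AffineForOp, 4> commonLoops;
     //   unsigned depth = getInnermostCommonLoopDepth({opA, opB}, &commonLoops);
     //   // 'commonLoops' now holds the 'depth' outermost loops shared by both.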
1411 
1412 /// Computes in 'sliceUnion' the union of all slice bounds computed at
1413 /// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
1414 /// then verifies if it is valid. Returns 'SliceComputationResult::Success' if
1415 /// union was computed correctly, an appropriate failure otherwise.
1416 SliceComputationResult
1417 mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
1418                                 ArrayRef<Operation *> opsB, unsigned loopDepth,
1419                                 unsigned numCommonLoops, bool isBackwardSlice,
1420                                 ComputationSliceState *sliceUnion) {
1421   // Compute the union of slice bounds between all pairs in 'opsA' and
1422   // 'opsB' in 'sliceUnionCst'.
1423   FlatAffineValueConstraints sliceUnionCst;
1424   assert(sliceUnionCst.getNumDimAndSymbolVars() == 0);
1425   std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
1426   for (auto *i : opsA) {
1427     MemRefAccess srcAccess(i);
1428     for (auto *j : opsB) {
1429       MemRefAccess dstAccess(j);
1430       if (srcAccess.memref != dstAccess.memref)
1431         continue;
1432       // Check if 'loopDepth' exceeds nesting depth of src/dst ops.
1433       if ((!isBackwardSlice && loopDepth > getNestingDepth(i)) ||
1434           (isBackwardSlice && loopDepth > getNestingDepth(j))) {
1435         LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
1436         return SliceComputationResult::GenericFailure;
1437       }
1438 
1439       bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
1440                               isa<AffineReadOpInterface>(dstAccess.opInst);
1441       FlatAffineValueConstraints dependenceConstraints;
1442       // Check dependence between 'srcAccess' and 'dstAccess'.
1443       DependenceResult result = checkMemrefAccessDependence(
1444           srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
1445           &dependenceConstraints, /*dependenceComponents=*/nullptr,
1446           /*allowRAR=*/readReadAccesses);
1447       if (result.value == DependenceResult::Failure) {
1448         LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
1449         return SliceComputationResult::GenericFailure;
1450       }
1451       if (result.value == DependenceResult::NoDependence)
1452         continue;
1453       dependentOpPairs.emplace_back(i, j);
1454 
1455       // Compute slice bounds for 'srcAccess' and 'dstAccess'.
1456       ComputationSliceState tmpSliceState;
1457       mlir::affine::getComputationSliceState(i, j, &dependenceConstraints,
1458                                              loopDepth, isBackwardSlice,
1459                                              &tmpSliceState);
1460 
1461       if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
1462         // Initialize 'sliceUnionCst' with the bounds computed in previous step.
1463         if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
1464           LLVM_DEBUG(llvm::dbgs()
1465                      << "Unable to compute slice bound constraints\n");
1466           return SliceComputationResult::GenericFailure;
1467         }
1468         assert(sliceUnionCst.getNumDimAndSymbolVars() > 0);
1469         continue;
1470       }
1471 
1472       // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
1473       FlatAffineValueConstraints tmpSliceCst;
1474       if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
1475         LLVM_DEBUG(llvm::dbgs()
1476                    << "Unable to compute slice bound constraints\n");
1477         return SliceComputationResult::GenericFailure;
1478       }
1479 
1480       // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
1481       if (!sliceUnionCst.areVarsAlignedWithOther(tmpSliceCst)) {
1482 
1483         // Pre-constraint var alignment: record loop IVs used in each constraint
1484         // system.
1485         SmallPtrSet<Value, 8> sliceUnionIVs;
1486         for (unsigned k = 0, l = sliceUnionCst.getNumDimVars(); k < l; ++k)
1487           sliceUnionIVs.insert(sliceUnionCst.getValue(k));
1488         SmallPtrSet<Value, 8> tmpSliceIVs;
1489         for (unsigned k = 0, l = tmpSliceCst.getNumDimVars(); k < l; ++k)
1490           tmpSliceIVs.insert(tmpSliceCst.getValue(k));
1491 
1492         sliceUnionCst.mergeAndAlignVarsWithOther(/*offset=*/0, &tmpSliceCst);
1493 
1494         // Post-constraint var alignment: add loop IV bounds missing after
1495         // var alignment to constraint systems. This can occur if one constraint
1496         // system uses a loop IV that is not used by the other. The call
1497         // to unionBoundingBox below expects constraints for each loop IV, even
1498         // if they are the unsliced full loop bounds added here.
1499         if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
1500           return SliceComputationResult::GenericFailure;
1501         if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
1502           return SliceComputationResult::GenericFailure;
1503       }
1504       // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
1505       if (sliceUnionCst.getNumLocalVars() > 0 ||
1506           tmpSliceCst.getNumLocalVars() > 0 ||
1507           failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
1508         LLVM_DEBUG(llvm::dbgs()
1509                    << "Unable to compute union bounding box of slice bounds\n");
1510         return SliceComputationResult::GenericFailure;
1511       }
1512     }
1513   }
1514 
1515   // Empty union: no slice bounds were computed for any dependent pair.
1516   if (sliceUnionCst.getNumDimAndSymbolVars() == 0)
1517     return SliceComputationResult::GenericFailure;
1518 
1519   // Gather loops surrounding ops from loop nest where slice will be inserted.
1520   SmallVector<Operation *, 4> ops;
1521   for (auto &dep : dependentOpPairs) {
1522     ops.push_back(isBackwardSlice ? dep.second : dep.first);
1523   }
1524   SmallVector<AffineForOp, 4> surroundingLoops;
1525   unsigned innermostCommonLoopDepth =
1526       getInnermostCommonLoopDepth(ops, &surroundingLoops);
1527   if (loopDepth > innermostCommonLoopDepth) {
1528     LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
1529     return SliceComputationResult::GenericFailure;
1530   }
1531 
1532   // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
1533   unsigned numSliceLoopIVs = sliceUnionCst.getNumDimVars();
1534 
1535   // Convert any dst loop IVs which are symbol variables to dim variables.
1536   sliceUnionCst.convertLoopIVSymbolsToDims();
1537   sliceUnion->clearBounds();
1538   sliceUnion->lbs.resize(numSliceLoopIVs, AffineMap());
1539   sliceUnion->ubs.resize(numSliceLoopIVs, AffineMap());
1540 
1541   // Get slice bounds from slice union constraints 'sliceUnionCst'.
1542   sliceUnionCst.getSliceBounds(/*offset=*/0, numSliceLoopIVs,
1543                                opsA[0]->getContext(), &sliceUnion->lbs,
1544                                &sliceUnion->ubs);
1545 
1546   // Add slice bound operands of union.
1547   SmallVector<Value, 4> sliceBoundOperands;
1548   sliceUnionCst.getValues(numSliceLoopIVs,
1549                           sliceUnionCst.getNumDimAndSymbolVars(),
1550                           &sliceBoundOperands);
1551 
1552   // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
1553   sliceUnion->ivs.clear();
1554   sliceUnionCst.getValues(0, numSliceLoopIVs, &sliceUnion->ivs);
1555 
1556   // Set the insertion point: block start for backward slices, block end otherwise.
1557   sliceUnion->insertPoint =
1558       isBackwardSlice
1559           ? surroundingLoops[loopDepth - 1].getBody()->begin()
1560           : std::prev(surroundingLoops[loopDepth - 1].getBody()->end());
1561 
1562   // Give each bound its own copy of 'sliceBoundOperands' for subsequent
1563   // canonicalization.
1564   sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1565   sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1566 
1567   // Check if the computed slice is valid. Return success only if the slice is
1568   // verified to be valid; otherwise return an appropriate failure status.
1569   std::optional<bool> isSliceValid = sliceUnion->isSliceValid();
1570   if (!isSliceValid) {
1571     LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
1572     return SliceComputationResult::GenericFailure;
1573   }
1574   if (!*isSliceValid)
1575     return SliceComputationResult::IncorrectSliceFailure;
1576 
1577   return SliceComputationResult::Success;
1578 }
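
     // A usage sketch (illustrative; 'srcStore' and 'dstLoad' are hypothetical
     // dependent accesses to the same memref):
     //
     //   ComputationSliceState sliceUnion;
     //   SliceComputationResult res = computeSliceUnion(
     //       {srcStore}, {dstLoad}, /*loopDepth=*/1, /*numCommonLoops=*/0,
     //       /*isBackwardSlice=*/true, &sliceUnion);
     //   if (res.value == SliceComputationResult::Success) {
     //     // 'sliceUnion' now holds the union of the source iterations needed
     //     // at the insertion point.
     //   }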
1579 
1580 // TODO: extend this to handle multiple result maps.
1581 static std::optional<uint64_t> getConstDifference(AffineMap lbMap,
1582                                                   AffineMap ubMap) {
1583   assert(lbMap.getNumResults() == 1 && "expected single result bound map");
1584   assert(ubMap.getNumResults() == 1 && "expected single result bound map");
1585   assert(lbMap.getNumDims() == ubMap.getNumDims());
1586   assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
1587   AffineExpr lbExpr(lbMap.getResult(0));
1588   AffineExpr ubExpr(ubMap.getResult(0));
1589   auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
1590                                          lbMap.getNumSymbols());
1591   auto cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr);
1592   if (!cExpr)
1593     return std::nullopt;
1594   return cExpr.getValue();
1595 }
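
     // For example (illustrative): with lbMap = (d0) -> (d0) and
     // ubMap = (d0) -> (d0 + 4), the difference simplifies to the constant 4;
     // with ubMap = (d0) -> (d0 * 2), the difference d0 is not constant and
     // std::nullopt is returned.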
1596 
1597 // Builds a map 'tripCountMap' from AffineForOp to constant trip count for the
1598 // loop nest represented by the slice loop bounds in 'slice'. Returns true
1599 // on success, false otherwise (if a non-constant trip count was encountered).
1600 // TODO: Make this work with non-unit step loops.
1601 bool mlir::affine::buildSliceTripCountMap(
1602     const ComputationSliceState &slice,
1603     llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
1604   unsigned numSrcLoopIVs = slice.ivs.size();
1605   // Populate map from AffineForOp -> trip count
1606   for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
1607     AffineForOp forOp = getForInductionVarOwner(slice.ivs[i]);
1608     auto *op = forOp.getOperation();
1609     AffineMap lbMap = slice.lbs[i];
1610     AffineMap ubMap = slice.ubs[i];
1611     // If the lower or upper bound maps are null or provide no results, the
1612     // source loop was not sliced at all, and the entire loop will be part of
1613     // the slice.
1614     if (!lbMap || lbMap.getNumResults() == 0 || !ubMap ||
1615         ubMap.getNumResults() == 0) {
1616       // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
1617       if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
1618         (*tripCountMap)[op] =
1619             forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
1620         continue;
1621       }
1622       std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
1623       if (maybeConstTripCount.has_value()) {
1624         (*tripCountMap)[op] = *maybeConstTripCount;
1625         continue;
1626       }
1627       return false;
1628     }
1629     std::optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
1630     // Slice bounds are created with a constant ub - lb difference.
1631     if (!tripCount.has_value())
1632       return false;
1633     (*tripCountMap)[op] = *tripCount;
1634   }
1635   return true;
1636 }
1637 
1638 // Return the number of iterations in the given slice.
1639 uint64_t mlir::affine::getSliceIterationCount(
1640     const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
1641   uint64_t iterCount = 1;
1642   for (const auto &count : sliceTripCountMap) {
1643     iterCount *= count.second;
1644   }
1645   return iterCount;
1646 }
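
     // For example (illustrative): a slice trip count map of {forI: 4, forJ: 8}
     // gives 4 * 8 = 32 iterations; an empty map gives 1.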
1647 
1648 const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
1649 // Computes slice bounds by projecting out any loop IVs from
1650 // 'dependenceConstraints' at depth greater than 'loopDepth', and computes slice
1651 // bounds in 'sliceState' which represent one loop nest's IVs in terms of
1652 // the other loop nest's IVs, symbols, and constants (using 'isBackwardSlice').
1653 void mlir::affine::getComputationSliceState(
1654     Operation *depSourceOp, Operation *depSinkOp,
1655     FlatAffineValueConstraints *dependenceConstraints, unsigned loopDepth,
1656     bool isBackwardSlice, ComputationSliceState *sliceState) {
1657   // Get loop nest surrounding src operation.
1658   SmallVector<AffineForOp, 4> srcLoopIVs;
1659   getAffineForIVs(*depSourceOp, &srcLoopIVs);
1660   unsigned numSrcLoopIVs = srcLoopIVs.size();
1661 
1662   // Get loop nest surrounding dst operation.
1663   SmallVector<AffineForOp, 4> dstLoopIVs;
1664   getAffineForIVs(*depSinkOp, &dstLoopIVs);
1665   unsigned numDstLoopIVs = dstLoopIVs.size();
1666 
1667   assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) ||
1668          (isBackwardSlice && loopDepth <= numDstLoopIVs));
1669 
1670   // Project out dimensions other than those up to 'loopDepth'.
1671   unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
1672   unsigned num =
1673       isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
1674   dependenceConstraints->projectOut(pos, num);
1675 
1676   // Add slice loop IV values to 'sliceState'.
1677   unsigned offset = isBackwardSlice ? 0 : loopDepth;
1678   unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
1679   dependenceConstraints->getValues(offset, offset + numSliceLoopIVs,
1680                                    &sliceState->ivs);
1681 
1682   // Set up lower/upper bound affine maps for the slice.
1683   sliceState->lbs.resize(numSliceLoopIVs, AffineMap());
1684   sliceState->ubs.resize(numSliceLoopIVs, AffineMap());
1685 
1686   // Get bounds for slice IVs in terms of other IVs, symbols, and constants.
1687   dependenceConstraints->getSliceBounds(offset, numSliceLoopIVs,
1688                                         depSourceOp->getContext(),
1689                                         &sliceState->lbs, &sliceState->ubs);
1690 
1691   // Set up bound operands for the slice's lower and upper bounds.
1692   SmallVector<Value, 4> sliceBoundOperands;
1693   unsigned numDimsAndSymbols = dependenceConstraints->getNumDimAndSymbolVars();
1694   for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
1695     if (i < offset || i >= offset + numSliceLoopIVs) {
1696       sliceBoundOperands.push_back(dependenceConstraints->getValue(i));
1697     }
1698   }
1699 
1700   // Give each bound its own copy of 'sliceBoundOperands' for subsequent
1701   // canonicalization.
1702   sliceState->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1703   sliceState->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1704 
1705   // Set the insertion point at 'loopDepth': dst block start (backward) or src block end (forward).
1706   sliceState->insertPoint =
1707       isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
1708                       : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());
1709 
1710   llvm::SmallDenseSet<Value, 8> sequentialLoops;
1711   if (isa<AffineReadOpInterface>(depSourceOp) &&
1712       isa<AffineReadOpInterface>(depSinkOp)) {
1713     // For read-read access pairs, clear any slice bounds on sequential loops.
1714     // Get sequential loops in the loop nest rooted at the slice's outermost loop.
1715     getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
1716                        &sequentialLoops);
1717   }
1718   auto getSliceLoop = [&](unsigned i) {
1719     return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
1720   };
1721   auto isInnermostInsertion = [&]() {
1722     return (isBackwardSlice ? loopDepth >= srcLoopIVs.size()
1723                             : loopDepth >= dstLoopIVs.size());
1724   };
1725   llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
1726   auto srcIsUnitSlice = [&]() {
1727     return (buildSliceTripCountMap(*sliceState, &sliceTripCountMap) &&
1728             (getSliceIterationCount(sliceTripCountMap) == 1));
1729   };
1730   // Clear all sliced loop bounds beginning at the first sequential loop, or
1731   // the first loop with a slice fusion barrier attribute.
1733   for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
1734     Value iv = getSliceLoop(i).getInductionVar();
1735     if (sequentialLoops.count(iv) == 0 &&
1736         getSliceLoop(i)->getAttr(kSliceFusionBarrierAttrName) == nullptr)
1737       continue;
1738     // Skip resetting the bounds of a reduction loop inserted in the
1739     // destination loop nest when all of the following conditions are met:
1740     //    1. The slice has a single trip count.
1741     //    2. The loop bounds of the source and destination match.
1742     //    3. The slice is being inserted at the innermost insertion point.
1743     std::optional<bool> isMaximal = sliceState->isMaximal();
1744     if (isLoopParallelAndContainsReduction(getSliceLoop(i)) &&
1745         isInnermostInsertion() && srcIsUnitSlice() && isMaximal && *isMaximal)
1746       continue;
1747     for (unsigned j = i; j < numSliceLoopIVs; ++j) {
1748       sliceState->lbs[j] = AffineMap();
1749       sliceState->ubs[j] = AffineMap();
1750     }
1751     break;
1752   }
1753 }
1754 
1755 /// Creates a computation slice of the loop nest surrounding 'srcOpInst',
1756 /// updates the slice loop bounds with any non-null bound maps specified in
1757 /// 'sliceState', and inserts this slice into the loop nest surrounding
1758 /// 'dstOpInst' at loop depth 'dstLoopDepth'.
1759 // TODO: extend the slicing utility to compute slices that aren't necessarily
1760 // a one-to-one relation between the source and destination. The relation
1761 // could be many-to-many in general.
1762 // TODO: the slice computation is incorrect in the cases
1763 // where the dependence from the source to the destination does not cover the
1764 // entire destination index set. Subtract out the dependent destination
1765 // iterations from destination index set and check for emptiness --- this is one
1766 // solution.
1767 AffineForOp mlir::affine::insertBackwardComputationSlice(
1768     Operation *srcOpInst, Operation *dstOpInst, unsigned dstLoopDepth,
1769     ComputationSliceState *sliceState) {
1770   // Get loop nest surrounding src operation.
1771   SmallVector<AffineForOp, 4> srcLoopIVs;
1772   getAffineForIVs(*srcOpInst, &srcLoopIVs);
1773   unsigned numSrcLoopIVs = srcLoopIVs.size();
1774 
1775   // Get loop nest surrounding dst operation.
1776   SmallVector<AffineForOp, 4> dstLoopIVs;
1777   getAffineForIVs(*dstOpInst, &dstLoopIVs);
1778   unsigned dstLoopIVsSize = dstLoopIVs.size();
1779   if (dstLoopDepth > dstLoopIVsSize) {
1780     dstOpInst->emitError("invalid destination loop depth");
1781     return AffineForOp();
1782   }
1783 
1784   // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
1785   SmallVector<unsigned, 4> positions;
1786   // TODO: This code is incorrect since srcLoopIVs can be 0-d.
1787   findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);
1788 
1789   // Clone src loop nest and insert it at the beginning of the operation block
1790   // of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
1791   auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
1792   OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
1793   auto sliceLoopNest =
1794       cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));
1795 
1796   Operation *sliceInst =
1797       getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
1798   // Get loop nest surrounding 'sliceInst'.
1799   SmallVector<AffineForOp, 4> sliceSurroundingLoops;
1800   getAffineForIVs(*sliceInst, &sliceSurroundingLoops);
1801 
1802   // Sanity check.
1803   unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
1804   (void)sliceSurroundingLoopsSize;
1805   assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
1806   unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
1807   (void)sliceLoopLimit;
1808   assert(sliceLoopLimit >= sliceSurroundingLoopsSize);
1809 
1810   // Update loop bounds for loops in 'sliceLoopNest'.
1811   for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
1812     auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
1813     if (AffineMap lbMap = sliceState->lbs[i])
1814       forOp.setLowerBound(sliceState->lbOperands[i], lbMap);
1815     if (AffineMap ubMap = sliceState->ubs[i])
1816       forOp.setUpperBound(sliceState->ubOperands[i], ubMap);
1817   }
1818   return sliceLoopNest;
1819 }
1820 
1821 // Constructs a MemRefAccess, populating it with the memref, its indices, and
1822 // the op from 'loadOrStoreOpInst'.
1823 MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
1824   if (auto loadOp = dyn_cast<AffineReadOpInterface>(loadOrStoreOpInst)) {
1825     memref = loadOp.getMemRef();
1826     opInst = loadOrStoreOpInst;
1827     llvm::append_range(indices, loadOp.getMapOperands());
1828   } else {
1829     assert(isa<AffineWriteOpInterface>(loadOrStoreOpInst) &&
1830            "Affine read/write op expected");
1831     auto storeOp = cast<AffineWriteOpInterface>(loadOrStoreOpInst);
1832     opInst = loadOrStoreOpInst;
1833     memref = storeOp.getMemRef();
1834     llvm::append_range(indices, storeOp.getMapOperands());
1835   }
1836 }
1837 
1838 unsigned MemRefAccess::getRank() const {
1839   return cast<MemRefType>(memref.getType()).getRank();
1840 }
1841 
1842 bool MemRefAccess::isStore() const {
1843   return isa<AffineWriteOpInterface>(opInst);
1844 }
1845 
1846 /// Returns the nesting depth of this operation, i.e., the number of loops
1847 /// surrounding this operation.
1848 unsigned mlir::affine::getNestingDepth(Operation *op) {
1849   Operation *currOp = op;
1850   unsigned depth = 0;
1851   while ((currOp = currOp->getParentOp())) {
1852     if (isa<AffineForOp>(currOp))
1853       depth++;
1854   }
1855   return depth;
1856 }
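
     // For example (illustrative): an op nested inside two 'affine.for' ops has
     // nesting depth 2; a top-level op has depth 0. Only 'affine.for' ancestors
     // are counted.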
1857 
1858 /// Equal if both affine accesses are provably equivalent (at compile
1859 /// time) when considering the memref, the affine maps and their respective
1860 /// operands. The equality of access functions + operands is checked by
1861 /// subtracting fully composed value maps, and then simplifying the difference
1862 /// using the expression flattener.
1863 /// TODO: this does not account for aliasing of memrefs.
1864 bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
1865   if (memref != rhs.memref)
1866     return false;
1867 
1868   AffineValueMap diff, thisMap, rhsMap;
1869   getAccessMap(&thisMap);
1870   rhs.getAccessMap(&rhsMap);
1871   AffineValueMap::difference(thisMap, rhsMap, &diff);
1872   return llvm::all_of(diff.getAffineMap().getResults(),
1873                       [](AffineExpr e) { return e == 0; });
1874 }
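
     // For example (illustrative): two 'affine.load %A[%i + 1]' ops using the
     // same %A and %i compare equal, since the difference of their composed
     // access maps simplifies to zero; 'affine.load %A[%i]' and
     // 'affine.load %A[%i + 1]' do not.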
1875 
1876 void mlir::affine::getAffineIVs(Operation &op, SmallVectorImpl<Value> &ivs) {
1877   auto *currOp = op.getParentOp();
1878   // Traverse up the hierarchy collecting all 'affine.for' and
1879   // 'affine.parallel' operations while skipping over 'affine.if' operations.
1881   while (currOp) {
1882     if (AffineForOp currAffineForOp = dyn_cast<AffineForOp>(currOp))
1883       ivs.push_back(currAffineForOp.getInductionVar());
1884     else if (auto parOp = dyn_cast<AffineParallelOp>(currOp))
1885       llvm::append_range(ivs, parOp.getIVs());
1886     currOp = currOp->getParentOp();
1887   }
1888   std::reverse(ivs.begin(), ivs.end());
1889 }
1890 
1891 /// Returns the number of surrounding loops common to both 'a' and 'b', where
1892 /// the common loops are ordered from outermost to innermost.
1893 unsigned mlir::affine::getNumCommonSurroundingLoops(Operation &a,
1894                                                     Operation &b) {
1895   SmallVector<Value, 4> loopsA, loopsB;
1896   getAffineIVs(a, loopsA);
1897   getAffineIVs(b, loopsB);
1898 
1899   unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
1900   unsigned numCommonLoops = 0;
1901   for (unsigned i = 0; i < minNumLoops; ++i) {
1902     if (loopsA[i] != loopsB[i])
1903       break;
1904     ++numCommonLoops;
1905   }
1906   return numCommonLoops;
1907 }
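
     // For example (illustrative): given
     //
     //   affine.for %i ... {
     //     affine.for %j ... { <a> }
     //     affine.for %k ... { <b> }
     //   }
     //
     // operations <a> and <b> share only the %i loop, so the result is 1.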
1908 
1909 static std::optional<int64_t> getMemoryFootprintBytes(Block &block,
1910                                                       Block::iterator start,
1911                                                       Block::iterator end,
1912                                                       int memorySpace) {
1913   SmallDenseMap<Value, std::unique_ptr<MemRefRegion>, 4> regions;
1914 
1915   // Walk this block range to gather all memory regions.
1916   auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
1917     if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst)) {
1918       // Neither a load nor a store op.
1919       return WalkResult::advance();
1920     }
1921 
1922     // Compute the memref region symbolic in any IVs enclosing this block.
1923     auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
1924     if (failed(
1925             region->compute(opInst,
1926                             /*loopDepth=*/getNestingDepth(&*block.begin())))) {
1927       return opInst->emitError("error obtaining memory region\n");
1928     }
1929 
1930     auto [it, inserted] = regions.try_emplace(region->memref);
1931     if (inserted) {
1932       it->second = std::move(region);
1933     } else if (failed(it->second->unionBoundingBox(*region))) {
1934       return opInst->emitWarning(
1935           "getMemoryFootprintBytes: unable to perform a union on a memory "
1936           "region");
1937     }
1938     return WalkResult::advance();
1939   });
1940   if (result.wasInterrupted())
1941     return std::nullopt;
1942 
1943   int64_t totalSizeInBytes = 0;
1944   for (const auto &region : regions) {
1945     std::optional<int64_t> size = region.second->getRegionSize();
1946     if (!size.has_value())
1947       return std::nullopt;
1948     totalSizeInBytes += *size;
1949   }
1950   return totalSizeInBytes;
1951 }
1952 
1953 std::optional<int64_t> mlir::affine::getMemoryFootprintBytes(AffineForOp forOp,
1954                                                              int memorySpace) {
1955   auto *forInst = forOp.getOperation();
1956   return ::getMemoryFootprintBytes(
1957       *forInst->getBlock(), Block::iterator(forInst),
1958       std::next(Block::iterator(forInst)), memorySpace);
1959 }
1960 
1961 /// Returns whether the loop is parallel and contains at least one reduction.
1962 bool mlir::affine::isLoopParallelAndContainsReduction(AffineForOp forOp) {
1963   SmallVector<LoopReduction> reductions;
1964   if (!isLoopParallel(forOp, &reductions))
1965     return false;
1966   return !reductions.empty();
1967 }
1968 
1969 /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
1970 /// at 'forOp'.
1971 void mlir::affine::getSequentialLoops(
1972     AffineForOp forOp, llvm::SmallDenseSet<Value, 8> *sequentialLoops) {
1973   forOp->walk([&](Operation *op) {
1974     if (auto innerFor = dyn_cast<AffineForOp>(op))
1975       if (!isLoopParallel(innerFor))
1976         sequentialLoops->insert(innerFor.getInductionVar());
1977   });
1978 }
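
     // A usage sketch (illustrative; 'rootForOp' is a hypothetical AffineForOp):
     //
     //   llvm::SmallDenseSet<Value, 8> sequentialIVs;
     //   getSequentialLoops(rootForOp, &sequentialIVs);
     //   bool nestIsFullyParallel = sequentialIVs.empty();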
1979 
1980 IntegerSet mlir::affine::simplifyIntegerSet(IntegerSet set) {
1981   FlatAffineValueConstraints fac(set);
1982   if (fac.isEmpty())
1983     return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
1984                                    set.getContext());
1985   fac.removeTrivialRedundancy();
1986 
1987   auto simplifiedSet = fac.getAsIntegerSet(set.getContext());
1988   assert(simplifiedSet && "guaranteed to succeed while roundtripping");
1989   return simplifiedSet;
1990 }
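
     // For example (illustrative): (d0) : (d0 >= 0, d0 + 10 >= 0) simplifies to
     // (d0) : (d0 >= 0), since the second constraint differs only in its looser
     // constant term; an infeasible set such as (d0) : (d0 - 1 >= 0, -d0 >= 0)
     // becomes the canonical empty set.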
1991 
1992 static void unpackOptionalValues(ArrayRef<std::optional<Value>> source,
1993                                  SmallVector<Value> &target) {
1994   target =
1995       llvm::to_vector<4>(llvm::map_range(source, [](std::optional<Value> val) {
1996         return val.has_value() ? *val : Value();
1997       }));
1998 }
1999 
2000 /// Bound an identifier `pos` in a given FlatAffineValueConstraints with
2001 /// constraints drawn from an affine map. Before adding the constraint, the
2002 /// dimensions/symbols of the affine map are aligned with `constraints`.
2003 /// `operands` are the SSA Value operands used with the affine map.
2004 /// Note: This function adds a new symbol column to the `constraints` for each
2005 /// dimension/symbol that exists in the affine map but not in `constraints`.
2006 static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints,
2007                                       BoundType type, unsigned pos,
2008                                       AffineMap map, ValueRange operands) {
2009   SmallVector<Value> dims, syms, newSyms;
2010   unpackOptionalValues(constraints.getMaybeValues(VarKind::SetDim), dims);
2011   unpackOptionalValues(constraints.getMaybeValues(VarKind::Symbol), syms);
2012 
2013   AffineMap alignedMap =
2014       alignAffineMapWithValues(map, operands, dims, syms, &newSyms);
2015   for (unsigned i = syms.size(); i < newSyms.size(); ++i)
2016     constraints.appendSymbolVar(newSyms[i]);
2017   return constraints.addBound(type, pos, alignedMap);
2018 }
2019 
2020 /// Add `val` to each result of `map`.
2021 static AffineMap addConstToResults(AffineMap map, int64_t val) {
2022   SmallVector<AffineExpr> newResults;
2023   for (AffineExpr r : map.getResults())
2024     newResults.push_back(r + val);
2025   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
2026                         map.getContext());
2027 }
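
     // For example (illustrative): addConstToResults((d0)[s0] -> (d0, s0 + 2), 1)
     // yields (d0)[s0] -> (d0 + 1, s0 + 3).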
2028 
2029 // Attempt to simplify the given min/max operation by proving that its value is
2030 // bounded by the same lower and upper bound.
2031 //
2032 // Bounds are computed by FlatAffineValueConstraints. Invariants required for
2033 // finding/proving bounds should be supplied via `constraints`.
2034 //
2035 // 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`).
2036 // 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in
2037 //    case of `!isMin`) and bind it to `opBound`. SSA values that are used in
2038 //    `op` but are not part of `constraints`, are added as extra symbols.
2039 // 3. For each result of `op`: Add result as a dimension `r_i`. Prove that:
2040 //    * If `isMin`: r_i >= opBound
2041 //    * If `!isMin`: r_i <= opBound
2042 //    If this is the case, ub(op) == lb(op).
2043 // 4. Replace `op` with `opBound`.
2044 //
2045 // In summary, the following constraints are added throughout this function.
2046 // Note: `invar` are dimensions added by the caller to express the invariants.
2047 // (Showing only the case where `isMin`.)
2048 //
2049 //  invar |    op | opBound | r_i | extra syms... | const |           eq/ineq
2050 //  ------+-------+---------+-----+---------------+-------+-------------------
2051 //   (various eq./ineq. constraining `invar`, added by the caller)
2052 //    ... |     0 |       0 |   0 |             0 |   ... |               ...
2053 //  ------+-------+---------+-----+---------------+-------+-------------------
2054 //  (various ineq. constraining `op` in terms of `op` operands (`invar` and
2055 //    extra `op` operands "extra syms" that are not in `invar`)).
2056 //    ... |    -1 |       0 |   0 |           ... |   ... |              >= 0
2057 //  ------+-------+---------+-----+---------------+-------+-------------------
2058 //   (set `opBound` to `op` upper bound in terms of `invar` and "extra syms")
2059 //    ... |     0 |      -1 |   0 |           ... |   ... |               = 0
2060 //  ------+-------+---------+-----+---------------+-------+-------------------
2061 //   (for each `op` map result r_i: set r_i to corresponding map result,
2062 //    prove that r_i >= minOpUb via contradiction)
2063 //    ... |     0 |       0 |  -1 |           ... |   ... |               = 0
2064 //      0 |     0 |       1 |  -1 |             0 |    -1 |              >= 0
2065 //
2066 FailureOr<AffineValueMap> mlir::affine::simplifyConstrainedMinMaxOp(
2067     Operation *op, FlatAffineValueConstraints constraints) {
2068   bool isMin = isa<AffineMinOp>(op);
2069   assert((isMin || isa<AffineMaxOp>(op)) && "expect AffineMin/MaxOp");
2070   MLIRContext *ctx = op->getContext();
2071   Builder builder(ctx);
2072   AffineMap map =
2073       isMin ? cast<AffineMinOp>(op).getMap() : cast<AffineMaxOp>(op).getMap();
2074   ValueRange operands = op->getOperands();
2075   unsigned numResults = map.getNumResults();
2076 
2077   // Add a few extra dimensions.
2078   unsigned dimOp = constraints.appendDimVar();      // `op`
2079   unsigned dimOpBound = constraints.appendDimVar(); // `op` lower/upper bound
2080   unsigned resultDimStart = constraints.appendDimVar(/*num=*/numResults);
2081 
2082   // Add an inequality for each result expr_i of map:
2083   // isMin: op <= expr_i, !isMin: op >= expr_i
2084   auto boundType = isMin ? BoundType::UB : BoundType::LB;
2085   // Upper bounds are exclusive, so add 1. (`affine.min` ops are inclusive.)
2086   AffineMap mapLbUb = isMin ? addConstToResults(map, 1) : map;
2087   if (failed(
2088           alignAndAddBound(constraints, boundType, dimOp, mapLbUb, operands)))
2089     return failure();
2090 
2091   // Try to compute a lower/upper bound for op, expressed in terms of the other
2092   // `dims` and extra symbols.
2093   SmallVector<AffineMap> opLb(1), opUb(1);
2094   constraints.getSliceBounds(dimOp, 1, ctx, &opLb, &opUb);
2095   AffineMap sliceBound = isMin ? opUb[0] : opLb[0];
2096   // TODO: `getSliceBounds` may return multiple bounds at the moment. This is
2097   // a TODO of `getSliceBounds` and not handled here.
2098   if (!sliceBound || sliceBound.getNumResults() != 1)
2099     return failure(); // No or multiple bounds found.
2100   // Recover the inclusive UB in the case of an `affine.min`.
2101   AffineMap boundMap = isMin ? addConstToResults(sliceBound, -1) : sliceBound;
2102 
2103   // Add an equality: Set dimOpBound to computed bound.
2104   // Add back dimension for op. (Was removed by `getSliceBounds`.)
2105   AffineMap alignedBoundMap = boundMap.shiftDims(/*shift=*/1, /*offset=*/dimOp);
2106   if (failed(constraints.addBound(BoundType::EQ, dimOpBound, alignedBoundMap)))
2107     return failure();
2108 
2109   // If the constraint system is empty, there is an inconsistency. (E.g., this
2110   // can happen if loop lb > ub.)
2111   if (constraints.isEmpty())
2112     return failure();
2113 
2114   // In the case of `isMin` (the `!isMin` case is symmetric):
2115   // Prove that each result of `map` has a lower bound that is equal to (or
2116   // greater than) the upper bound of `op` (`dimOpBound`). In that case, `op`
2117   // can be replaced with the bound. I.e., prove that for each result
2118   // expr_i (represented by dimension r_i):
2119   //
2120   // r_i >= opBound
2121   //
2122   // To prove this inequality, add its negation to the constraint set and prove
2123   // that the constraint set is empty.
2124   for (unsigned i = resultDimStart; i < resultDimStart + numResults; ++i) {
2125     FlatAffineValueConstraints newConstr(constraints);
2126 
2127     // Add an equality: r_i = expr_i
2128     // Note: These equalities could have been added earlier and used to express
2129     // minOp <= expr_i. However, then we run the risk that `getSliceBounds`
2130     // computes minOpUb in terms of r_i dims, which is not desired.
2131     if (failed(alignAndAddBound(newConstr, BoundType::EQ, i,
2132                                 map.getSubMap({i - resultDimStart}), operands)))
2133       return failure();
2134 
2135     // If `isMin`:  Add inequality: r_i < opBound
2136     //              equiv.: opBound - r_i - 1 >= 0
2137     // If `!isMin`: Add inequality: r_i > opBound
2138     //              equiv.: -opBound + r_i - 1 >= 0
2139     SmallVector<int64_t> ineq(newConstr.getNumCols(), 0);
2140     ineq[dimOpBound] = isMin ? 1 : -1;
2141     ineq[i] = isMin ? -1 : 1;
2142     ineq[newConstr.getNumCols() - 1] = -1;
2143     newConstr.addInequality(ineq);
2144     if (!newConstr.isEmpty())
2145       return failure();
2146   }
2147 
2148   // Lower and upper bound of `op` are equal. Replace `op` with its bound.
2149   AffineMap newMap = alignedBoundMap;
2150   SmallVector<Value> newOperands;
2151   unpackOptionalValues(constraints.getMaybeValues(), newOperands);
2152   // If dims/symbols have known constant values, use those in order to simplify
2153   // the affine map further.
2154   for (int64_t i = 0, e = constraints.getNumDimAndSymbolVars(); i < e; ++i) {
2155     // Skip unused operands and operands that are already constants.
2156     if (!newOperands[i] || getConstantIntValue(newOperands[i]))
2157       continue;
2158     if (auto bound = constraints.getConstantBound64(BoundType::EQ, i)) {
2159       AffineExpr expr =
2160           i < newMap.getNumDims()
2161               ? builder.getAffineDimExpr(i)
2162               : builder.getAffineSymbolExpr(i - newMap.getNumDims());
2163       newMap = newMap.replace(expr, builder.getAffineConstantExpr(*bound),
2164                               newMap.getNumDims(), newMap.getNumSymbols());
2165     }
2166   }
2167   affine::canonicalizeMapAndOperands(&newMap, &newOperands);
2168   return AffineValueMap(newMap, newOperands);
2169 }
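
     // An illustrative example (a sketch, not from the source): given
     //
     //   %r = affine.min affine_map<(d0) -> (d0 + 2, 10)>(%i)
     //
     // with 'constraints' proving 0 <= %i <= 7, the computed upper bound of the
     // op is d0 + 2, and both results are >= that bound (d0 + 2 <= 9 < 10), so
     // the op simplifies to the single-result map affine_map<(d0) -> (d0 + 2)>
     // applied to %i.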
2170