xref: /llvm-project/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp (revision 9f2dd085ae981740e2986a1b200ca2a7df44953d)
1 //===- ValueBoundsOpInterface.cpp - Value Bounds  -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
10 
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
14 #include "mlir/Interfaces/ViewLikeInterface.h"
15 #include "llvm/ADT/APSInt.h"
16 #include "llvm/Support/Debug.h"
17 
18 #define DEBUG_TYPE "value-bounds-op-interface"
19 
20 using namespace mlir;
21 using presburger::BoundType;
22 using presburger::VarKind;
23 
24 namespace mlir {
25 #include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
26 } // namespace mlir
27 
28 static Operation *getOwnerOfValue(Value value) {
29   if (auto bbArg = dyn_cast<BlockArgument>(value))
30     return bbArg.getOwner()->getParentOp();
31   return value.getDefiningOp();
32 }
33 
34 HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
35                                              ArrayRef<OpFoldResult> sizes,
36                                              ArrayRef<OpFoldResult> strides)
37     : mixedOffsets(offsets), mixedSizes(sizes), mixedStrides(strides) {
38   assert(offsets.size() == sizes.size() &&
39          "expected same number of offsets, sizes, strides");
40   assert(offsets.size() == strides.size() &&
41          "expected same number of offsets, sizes, strides");
42 }
43 
44 HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
45                                              ArrayRef<OpFoldResult> sizes)
46     : mixedOffsets(offsets), mixedSizes(sizes) {
47   assert(offsets.size() == sizes.size() &&
48          "expected same number of offsets and sizes");
49   // Assume that all strides are 1.
50   if (offsets.empty())
51     return;
52   MLIRContext *ctx = offsets.front().getContext();
53   mixedStrides.append(offsets.size(), Builder(ctx).getIndexAttr(1));
54 }
55 
56 HyperrectangularSlice::HyperrectangularSlice(OffsetSizeAndStrideOpInterface op)
57     : HyperrectangularSlice(op.getMixedOffsets(), op.getMixedSizes(),
58                             op.getMixedStrides()) {}
59 
60 /// If ofr is a constant integer or an IntegerAttr, return the integer.
61 static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
62   // Case 1: Check for Constant integer.
63   if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
64     APSInt intVal;
65     if (matchPattern(val, m_ConstantInt(&intVal)))
66       return intVal.getSExtValue();
67     return std::nullopt;
68   }
69   // Case 2: Check for IntegerAttr.
70   Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
71   if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
72     return intAttr.getValue().getSExtValue();
73   return std::nullopt;
74 }
75 
76 ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr)
77     : Variable(ofr, std::nullopt) {}
78 
79 ValueBoundsConstraintSet::Variable::Variable(Value indexValue)
80     : Variable(static_cast<OpFoldResult>(indexValue)) {}
81 
82 ValueBoundsConstraintSet::Variable::Variable(Value shapedValue, int64_t dim)
83     : Variable(static_cast<OpFoldResult>(shapedValue), std::optional(dim)) {}
84 
85 ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr,
86                                              std::optional<int64_t> dim) {
87   Builder b(ofr.getContext());
88   if (auto constInt = ::getConstantIntValue(ofr)) {
89     assert(!dim && "expected no dim for index-typed values");
90     map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
91                          b.getAffineConstantExpr(*constInt));
92     return;
93   }
94   Value value = cast<Value>(ofr);
95 #ifndef NDEBUG
96   if (dim) {
97     assert(isa<ShapedType>(value.getType()) && "expected shaped type");
98   } else {
99     assert(value.getType().isIndex() && "expected index type");
100   }
101 #endif // NDEBUG
102   map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
103                        b.getAffineSymbolExpr(0));
104   mapOperands.emplace_back(value, dim);
105 }
106 
107 ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
108                                              ArrayRef<Variable> mapOperands) {
109   assert(map.getNumResults() == 1 && "expected single result");
110 
111   // Turn all dims into symbols.
112   Builder b(map.getContext());
113   SmallVector<AffineExpr> dimReplacements, symReplacements;
114   for (int64_t i = 0, e = map.getNumDims(); i < e; ++i)
115     dimReplacements.push_back(b.getAffineSymbolExpr(i));
116   for (int64_t i = 0, e = map.getNumSymbols(); i < e; ++i)
117     symReplacements.push_back(b.getAffineSymbolExpr(i + map.getNumDims()));
118   AffineMap tmpMap = map.replaceDimsAndSymbols(
119       dimReplacements, symReplacements, /*numResultDims=*/0,
120       /*numResultSyms=*/map.getNumSymbols() + map.getNumDims());
121 
122   // Inline operands.
123   DenseMap<AffineExpr, AffineExpr> replacements;
124   for (auto [index, var] : llvm::enumerate(mapOperands)) {
125     assert(var.map.getNumResults() == 1 && "expected single result");
126     assert(var.map.getNumDims() == 0 && "expected only symbols");
127     SmallVector<AffineExpr> symReplacements;
128     for (auto valueDim : var.mapOperands) {
129       auto it = llvm::find(this->mapOperands, valueDim);
130       if (it != this->mapOperands.end()) {
131         // There is already a symbol for this operand.
132         symReplacements.push_back(b.getAffineSymbolExpr(
133             std::distance(this->mapOperands.begin(), it)));
134       } else {
135         // This is a new operand: add a new symbol.
136         symReplacements.push_back(
137             b.getAffineSymbolExpr(this->mapOperands.size()));
138         this->mapOperands.push_back(valueDim);
139       }
140     }
141     replacements[b.getAffineSymbolExpr(index)] =
142         var.map.getResult(0).replaceSymbols(symReplacements);
143   }
144   this->map = tmpMap.replace(replacements, /*numResultDims=*/0,
145                              /*numResultSyms=*/this->mapOperands.size());
146 }
147 
148 ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
149                                              ArrayRef<Value> mapOperands)
150     : Variable(map, llvm::map_to_vector(mapOperands,
151                                         [](Value v) { return Variable(v); })) {}
152 
153 ValueBoundsConstraintSet::ValueBoundsConstraintSet(
154     MLIRContext *ctx, StopConditionFn stopCondition,
155     bool addConservativeSemiAffineBounds)
156     : builder(ctx), stopCondition(stopCondition),
157       addConservativeSemiAffineBounds(addConservativeSemiAffineBounds) {
158   assert(stopCondition && "expected non-null stop condition");
159 }
160 
// Out-of-class definition of the static `ID` member.
// NOTE(review): presumably only the address of `ID` matters (LLVM-style class
// identification) — confirm against the header.
char ValueBoundsConstraintSet::ID = 0;
162 
163 #ifndef NDEBUG
164 static void assertValidValueDim(Value value, std::optional<int64_t> dim) {
165   if (value.getType().isIndex()) {
166     assert(!dim.has_value() && "invalid dim value");
167   } else if (auto shapedType = dyn_cast<ShapedType>(value.getType())) {
168     assert(*dim >= 0 && "invalid dim value");
169     if (shapedType.hasRank())
170       assert(*dim < shapedType.getRank() && "invalid dim value");
171   } else {
172     llvm_unreachable("unsupported type");
173   }
174 }
175 #endif // NDEBUG
176 
177 void ValueBoundsConstraintSet::addBound(BoundType type, int64_t pos,
178                                         AffineExpr expr) {
179   // Note: If `addConservativeSemiAffineBounds` is true then the bound
180   // computation function needs to handle the case that the constraints set
181   // could become empty. This is because the conservative bounds add assumptions
182   // (e.g. for `mod` it assumes `rhs > 0`). If these constraints are later found
183   // not to hold, then the bound is invalid.
184   LogicalResult status = cstr.addBound(
185       type, pos,
186       AffineMap::get(cstr.getNumDimVars(), cstr.getNumSymbolVars(), expr),
187       addConservativeSemiAffineBounds
188           ? FlatLinearConstraints::AddConservativeSemiAffineBounds::Yes
189           : FlatLinearConstraints::AddConservativeSemiAffineBounds::No);
190   if (failed(status)) {
191     // Not all semi-affine expressions are not yet supported by
192     // FlatLinearConstraints. However, we can just ignore such failures here.
193     // Even without this bound, there may be enough information in the
194     // constraint system to compute the requested bound. In case this bound is
195     // actually needed, `computeBound` will return `failure`.
196     LLVM_DEBUG(llvm::dbgs() << "Failed to add bound: " << expr << "\n");
197   }
198 }
199 
200 AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
201                                              std::optional<int64_t> dim) {
202 #ifndef NDEBUG
203   assertValidValueDim(value, dim);
204 #endif // NDEBUG
205 
206   // Check if the value/dim is statically known. In that case, an affine
207   // constant expression should be returned. This allows us to support
208   // multiplications with constants. (Multiplications of two columns in the
209   // constraint set is not supported.)
210   std::optional<int64_t> constSize = std::nullopt;
211   auto shapedType = dyn_cast<ShapedType>(value.getType());
212   if (shapedType) {
213     if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim))
214       constSize = shapedType.getDimSize(*dim);
215   } else if (auto constInt = ::getConstantIntValue(value)) {
216     constSize = *constInt;
217   }
218 
219   // If the value/dim is already mapped, return the corresponding expression
220   // directly.
221   ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
222   if (valueDimToPosition.contains(valueDim)) {
223     // If it is a constant, return an affine constant expression. Otherwise,
224     // return an affine expression that represents the respective column in the
225     // constraint set.
226     if (constSize)
227       return builder.getAffineConstantExpr(*constSize);
228     return getPosExpr(getPos(value, dim));
229   }
230 
231   if (constSize) {
232     // Constant index value/dim: add column to the constraint set, add EQ bound
233     // and return an affine constant expression without pushing the newly added
234     // column to the worklist.
235     (void)insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false);
236     if (shapedType)
237       bound(value)[*dim] == *constSize;
238     else
239       bound(value) == *constSize;
240     return builder.getAffineConstantExpr(*constSize);
241   }
242 
243   // Dynamic value/dim: insert column to the constraint set and put it on the
244   // worklist. Return an affine expression that represents the newly inserted
245   // column in the constraint set.
246   return getPosExpr(insert(value, dim, /*isSymbol=*/true));
247 }
248 
249 AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
250   if (Value value = llvm::dyn_cast_if_present<Value>(ofr))
251     return getExpr(value, /*dim=*/std::nullopt);
252   auto constInt = ::getConstantIntValue(ofr);
253   assert(constInt.has_value() && "expected Integer constant");
254   return builder.getAffineConstantExpr(*constInt);
255 }
256 
257 AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {
258   return builder.getAffineConstantExpr(constant);
259 }
260 
261 int64_t ValueBoundsConstraintSet::insert(Value value,
262                                          std::optional<int64_t> dim,
263                                          bool isSymbol, bool addToWorklist) {
264 #ifndef NDEBUG
265   assertValidValueDim(value, dim);
266 #endif // NDEBUG
267 
268   ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
269   assert(!valueDimToPosition.contains(valueDim) && "already mapped");
270   int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
271                          : cstr.appendVar(VarKind::SetDim);
272   LLVM_DEBUG(llvm::dbgs() << "Inserting constraint set column " << pos
273                           << " for: " << value
274                           << " (dim: " << dim.value_or(kIndexValue)
275                           << ", owner: " << getOwnerOfValue(value)->getName()
276                           << ")\n");
277   positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim);
278   // Update reverse mapping.
279   for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
280     if (positionToValueDim[i].has_value())
281       valueDimToPosition[*positionToValueDim[i]] = i;
282 
283   if (addToWorklist) {
284     LLVM_DEBUG(llvm::dbgs() << "Push to worklist: " << value
285                             << " (dim: " << dim.value_or(kIndexValue) << ")\n");
286     worklist.push(pos);
287   }
288 
289   return pos;
290 }
291 
292 int64_t ValueBoundsConstraintSet::insert(bool isSymbol) {
293   int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
294                          : cstr.appendVar(VarKind::SetDim);
295   LLVM_DEBUG(llvm::dbgs() << "Inserting anonymous constraint set column " << pos
296                           << "\n");
297   positionToValueDim.insert(positionToValueDim.begin() + pos, std::nullopt);
298   // Update reverse mapping.
299   for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
300     if (positionToValueDim[i].has_value())
301       valueDimToPosition[*positionToValueDim[i]] = i;
302   return pos;
303 }
304 
305 int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands,
306                                          bool isSymbol) {
307   assert(map.getNumResults() == 1 && "expected affine map with one result");
308   int64_t pos = insert(isSymbol);
309 
310   // Add map and operands to the constraint set. Dimensions are converted to
311   // symbols. All operands are added to the worklist (unless they were already
312   // processed).
313   auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
314     return getExpr(v.first, v.second);
315   };
316   SmallVector<AffineExpr> dimReplacements = llvm::to_vector(
317       llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper));
318   SmallVector<AffineExpr> symReplacements = llvm::to_vector(
319       llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper));
320   addBound(
321       presburger::BoundType::EQ, pos,
322       map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
323 
324   return pos;
325 }
326 
327 int64_t ValueBoundsConstraintSet::insert(const Variable &var, bool isSymbol) {
328   return insert(var.map, var.mapOperands, isSymbol);
329 }
330 
331 int64_t ValueBoundsConstraintSet::getPos(Value value,
332                                          std::optional<int64_t> dim) const {
333 #ifndef NDEBUG
334   assertValidValueDim(value, dim);
335   assert((isa<OpResult>(value) ||
336           cast<BlockArgument>(value).getOwner()->isEntryBlock()) &&
337          "unstructured control flow is not supported");
338 #endif // NDEBUG
339   LLVM_DEBUG(llvm::dbgs() << "Getting pos for: " << value
340                           << " (dim: " << dim.value_or(kIndexValue)
341                           << ", owner: " << getOwnerOfValue(value)->getName()
342                           << ")\n");
343   auto it =
344       valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
345   assert(it != valueDimToPosition.end() && "expected mapped entry");
346   return it->second;
347 }
348 
349 AffineExpr ValueBoundsConstraintSet::getPosExpr(int64_t pos) {
350   assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() && "invalid position");
351   return pos < cstr.getNumDimVars()
352              ? builder.getAffineDimExpr(pos)
353              : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
354 }
355 
356 bool ValueBoundsConstraintSet::isMapped(Value value,
357                                         std::optional<int64_t> dim) const {
358   auto it =
359       valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
360   return it != valueDimToPosition.end();
361 }
362 
363 void ValueBoundsConstraintSet::processWorklist() {
364   LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
365   while (!worklist.empty()) {
366     int64_t pos = worklist.front();
367     worklist.pop();
368     assert(positionToValueDim[pos].has_value() &&
369            "did not expect std::nullopt on worklist");
370     ValueDim valueDim = *positionToValueDim[pos];
371     Value value = valueDim.first;
372     int64_t dim = valueDim.second;
373 
374     // Check for static dim size.
375     if (dim != kIndexValue) {
376       auto shapedType = cast<ShapedType>(value.getType());
377       if (shapedType.hasRank() && !shapedType.isDynamicDim(dim)) {
378         bound(value)[dim] == getExpr(shapedType.getDimSize(dim));
379         continue;
380       }
381     }
382 
383     // Do not process any further if the stop condition is met.
384     auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
385     if (stopCondition(value, maybeDim, *this)) {
386       LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value
387                               << " (dim: " << maybeDim << ")\n");
388       continue;
389     }
390 
391     // Query `ValueBoundsOpInterface` for constraints. New items may be added to
392     // the worklist.
393     auto valueBoundsOp =
394         dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
395     LLVM_DEBUG(llvm::dbgs()
396                << "Query value bounds for: " << value
397                << " (owner: " << getOwnerOfValue(value)->getName() << ")\n");
398     if (valueBoundsOp) {
399       if (dim == kIndexValue) {
400         valueBoundsOp.populateBoundsForIndexValue(value, *this);
401       } else {
402         valueBoundsOp.populateBoundsForShapedValueDim(value, dim, *this);
403       }
404       continue;
405     }
406     LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n");
407 
408     // If the op does not implement `ValueBoundsOpInterface`, check if it
409     // implements the `DestinationStyleOpInterface`. OpResults of such ops are
410     // tied to OpOperands. Tied values have the same shape.
411     auto dstOp = value.getDefiningOp<DestinationStyleOpInterface>();
412     if (!dstOp || dim == kIndexValue)
413       continue;
414     Value tiedOperand = dstOp.getTiedOpOperand(cast<OpResult>(value))->get();
415     bound(value)[dim] == getExpr(tiedOperand, dim);
416   }
417 }
418 
419 void ValueBoundsConstraintSet::projectOut(int64_t pos) {
420   assert(pos >= 0 && pos < static_cast<int64_t>(positionToValueDim.size()) &&
421          "invalid position");
422   cstr.projectOut(pos);
423   if (positionToValueDim[pos].has_value()) {
424     bool erased = valueDimToPosition.erase(*positionToValueDim[pos]);
425     (void)erased;
426     assert(erased && "inconsistent reverse mapping");
427   }
428   positionToValueDim.erase(positionToValueDim.begin() + pos);
429   // Update reverse mapping.
430   for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
431     if (positionToValueDim[i].has_value())
432       valueDimToPosition[*positionToValueDim[i]] = i;
433 }
434 
435 void ValueBoundsConstraintSet::projectOut(
436     function_ref<bool(ValueDim)> condition) {
437   int64_t nextPos = 0;
438   while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
439     if (positionToValueDim[nextPos].has_value() &&
440         condition(*positionToValueDim[nextPos])) {
441       projectOut(nextPos);
442       // The column was projected out so another column is now at that position.
443       // Do not increase the counter.
444     } else {
445       ++nextPos;
446     }
447   }
448 }
449 
450 void ValueBoundsConstraintSet::projectOutAnonymous(
451     std::optional<int64_t> except) {
452   int64_t nextPos = 0;
453   while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
454     if (positionToValueDim[nextPos].has_value() || except == nextPos) {
455       ++nextPos;
456     } else {
457       projectOut(nextPos);
458       // The column was projected out so another column is now at that position.
459       // Do not increase the counter.
460     }
461   }
462 }
463 
464 LogicalResult ValueBoundsConstraintSet::computeBound(
465     AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
466     const Variable &var, StopConditionFn stopCondition, bool closedUB) {
467   MLIRContext *ctx = var.getContext();
468   int64_t ubAdjustment = closedUB ? 0 : 1;
469   Builder b(ctx);
470   mapOperands.clear();
471 
472   // Process the backward slice of `value` (i.e., reverse use-def chain) until
473   // `stopCondition` is met.
474   ValueBoundsConstraintSet cstr(ctx, stopCondition);
475   int64_t pos = cstr.insert(var, /*isSymbol=*/false);
476   assert(pos == 0 && "expected first column");
477   cstr.processWorklist();
478 
479   // Project out all variables (apart from `valueDim`) that do not match the
480   // stop condition.
481   cstr.projectOut([&](ValueDim p) {
482     auto maybeDim =
483         p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
484     return !stopCondition(p.first, maybeDim, cstr);
485   });
486   cstr.projectOutAnonymous(/*except=*/pos);
487 
488   // Compute lower and upper bounds for `valueDim`.
489   SmallVector<AffineMap> lb(1), ub(1);
490   cstr.cstr.getSliceBounds(pos, 1, ctx, &lb, &ub,
491                            /*closedUB=*/true);
492 
493   // Note: There are TODOs in the implementation of `getSliceBounds`. In such a
494   // case, no lower/upper bound can be computed at the moment.
495   // EQ, UB bounds: upper bound is needed.
496   if ((type != BoundType::LB) &&
497       (ub.empty() || !ub[0] || ub[0].getNumResults() == 0))
498     return failure();
499   // EQ, LB bounds: lower bound is needed.
500   if ((type != BoundType::UB) &&
501       (lb.empty() || !lb[0] || lb[0].getNumResults() == 0))
502     return failure();
503 
504   // TODO: Generate an affine map with multiple results.
505   if (type != BoundType::LB)
506     assert(ub.size() == 1 && ub[0].getNumResults() == 1 &&
507            "multiple bounds not supported");
508   if (type != BoundType::UB)
509     assert(lb.size() == 1 && lb[0].getNumResults() == 1 &&
510            "multiple bounds not supported");
511 
512   // EQ bound: lower and upper bound must match.
513   if (type == BoundType::EQ && ub[0] != lb[0])
514     return failure();
515 
516   AffineMap bound;
517   if (type == BoundType::EQ || type == BoundType::LB) {
518     bound = lb[0];
519   } else {
520     // Computed UB is a closed bound.
521     bound = AffineMap::get(ub[0].getNumDims(), ub[0].getNumSymbols(),
522                            ub[0].getResult(0) + ubAdjustment);
523   }
524 
525   // Gather all SSA values that are used in the computed bound.
526   assert(cstr.cstr.getNumDimAndSymbolVars() == cstr.positionToValueDim.size() &&
527          "inconsistent mapping state");
528   SmallVector<AffineExpr> replacementDims, replacementSymbols;
529   int64_t numDims = 0, numSymbols = 0;
530   for (int64_t i = 0; i < cstr.cstr.getNumDimAndSymbolVars(); ++i) {
531     // Skip `value`.
532     if (i == pos)
533       continue;
534     // Check if the position `i` is used in the generated bound. If so, it must
535     // be included in the generated affine.apply op.
536     bool used = false;
537     bool isDim = i < cstr.cstr.getNumDimVars();
538     if (isDim) {
539       if (bound.isFunctionOfDim(i))
540         used = true;
541     } else {
542       if (bound.isFunctionOfSymbol(i - cstr.cstr.getNumDimVars()))
543         used = true;
544     }
545 
546     if (!used) {
547       // Not used: Remove dim/symbol from the result.
548       if (isDim) {
549         replacementDims.push_back(b.getAffineConstantExpr(0));
550       } else {
551         replacementSymbols.push_back(b.getAffineConstantExpr(0));
552       }
553       continue;
554     }
555 
556     if (isDim) {
557       replacementDims.push_back(b.getAffineDimExpr(numDims++));
558     } else {
559       replacementSymbols.push_back(b.getAffineSymbolExpr(numSymbols++));
560     }
561 
562     assert(cstr.positionToValueDim[i].has_value() &&
563            "cannot build affine map in terms of anonymous column");
564     ValueBoundsConstraintSet::ValueDim valueDim = *cstr.positionToValueDim[i];
565     Value value = valueDim.first;
566     int64_t dim = valueDim.second;
567     if (dim == ValueBoundsConstraintSet::kIndexValue) {
568       // An index-type value is used: can be used directly in the affine.apply
569       // op.
570       assert(value.getType().isIndex() && "expected index type");
571       mapOperands.push_back(std::make_pair(value, std::nullopt));
572       continue;
573     }
574 
575     assert(cast<ShapedType>(value.getType()).isDynamicDim(dim) &&
576            "expected dynamic dim");
577     mapOperands.push_back(std::make_pair(value, dim));
578   }
579 
580   resultMap = bound.replaceDimsAndSymbols(replacementDims, replacementSymbols,
581                                           numDims, numSymbols);
582   return success();
583 }
584 
585 LogicalResult ValueBoundsConstraintSet::computeDependentBound(
586     AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
587     const Variable &var, ValueDimList dependencies, bool closedUB) {
588   return computeBound(
589       resultMap, mapOperands, type, var,
590       [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
591         return llvm::is_contained(dependencies, std::make_pair(v, d));
592       },
593       closedUB);
594 }
595 
596 LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
597     AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
598     const Variable &var, ValueRange independencies, bool closedUB) {
599   // Return "true" if the given value is independent of all values in
600   // `independencies`. I.e., neither the value itself nor any value in the
601   // backward slice (reverse use-def chain) is contained in `independencies`.
602   auto isIndependent = [&](Value v) {
603     SmallVector<Value> worklist;
604     DenseSet<Value> visited;
605     worklist.push_back(v);
606     while (!worklist.empty()) {
607       Value next = worklist.pop_back_val();
608       if (!visited.insert(next).second)
609         continue;
610       if (llvm::is_contained(independencies, next))
611         return false;
612       // TODO: DominanceInfo could be used to stop the traversal early.
613       Operation *op = next.getDefiningOp();
614       if (!op)
615         continue;
616       worklist.append(op->getOperands().begin(), op->getOperands().end());
617     }
618     return true;
619   };
620 
621   // Reify bounds in terms of any independent values.
622   return computeBound(
623       resultMap, mapOperands, type, var,
624       [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
625         return isIndependent(v);
626       },
627       closedUB);
628 }
629 
630 FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
631     presburger::BoundType type, const Variable &var,
632     StopConditionFn stopCondition, bool closedUB) {
633   // Default stop condition if none was specified: Keep adding constraints until
634   // a bound could be computed.
635   int64_t pos = 0;
636   auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
637                                   ValueBoundsConstraintSet &cstr) {
638     return cstr.cstr.getConstantBound64(type, pos).has_value();
639   };
640 
641   ValueBoundsConstraintSet cstr(
642       var.getContext(), stopCondition ? stopCondition : defaultStopCondition);
643   pos = cstr.populateConstraints(var.map, var.mapOperands);
644   assert(pos == 0 && "expected `map` is the first column");
645 
646   // Compute constant bound for `valueDim`.
647   int64_t ubAdjustment = closedUB ? 0 : 1;
648   if (auto bound = cstr.cstr.getConstantBound64(type, pos))
649     return type == BoundType::UB ? *bound + ubAdjustment : *bound;
650   return failure();
651 }
652 
653 void ValueBoundsConstraintSet::populateConstraints(Value value,
654                                                    std::optional<int64_t> dim) {
655 #ifndef NDEBUG
656   assertValidValueDim(value, dim);
657 #endif // NDEBUG
658 
659   // `getExpr` pushes the value/dim onto the worklist (unless it was already
660   // analyzed).
661   (void)getExpr(value, dim);
662   // Process all values/dims on the worklist. This may traverse and analyze
663   // additional IR, depending the current stop function.
664   processWorklist();
665 }
666 
667 int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map,
668                                                       ValueDimList operands) {
669   int64_t pos = insert(map, operands, /*isSymbol=*/false);
670   // Process the backward slice of `operands` (i.e., reverse use-def chain)
671   // until `stopCondition` is met.
672   processWorklist();
673   return pos;
674 }
675 
676 FailureOr<int64_t>
677 ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
678                                                std::optional<int64_t> dim1,
679                                                std::optional<int64_t> dim2) {
680 #ifndef NDEBUG
681   assertValidValueDim(value1, dim1);
682   assertValidValueDim(value2, dim2);
683 #endif // NDEBUG
684 
685   Builder b(value1.getContext());
686   AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
687                                  b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
688   return computeConstantBound(presburger::BoundType::EQ,
689                               Variable(map, {{value1, dim1}, {value2, dim2}}));
690 }
691 
692 bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
693                                           ComparisonOperator cmp,
694                                           int64_t rhsPos) {
695   // This function returns "true" if "lhs CMP rhs" is proven to hold.
696   //
697   // Example for ComparisonOperator::LE and index-typed values: We would like to
698   // prove that lhs <= rhs. Proof by contradiction: add the inverse
699   // relation (lhs > rhs) to the constraint set and check if the resulting
700   // constraint set is "empty" (i.e. has no solution). In that case,
701   // lhs > rhs must be incorrect and we can deduce that lhs <= rhs holds.
702 
703   // We cannot prove anything if the constraint set is already empty.
704   if (cstr.isEmpty()) {
705     LLVM_DEBUG(
706         llvm::dbgs()
707         << "cannot compare value/dims: constraint system is already empty");
708     return false;
709   }
710 
711   // EQ can be expressed as LE and GE.
712   if (cmp == EQ)
713     return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) &&
714            comparePos(lhsPos, ComparisonOperator::GE, rhsPos);
715 
716   // Construct inequality.
717   SmallVector<int64_t> eq(cstr.getNumCols(), 0);
718   if (cmp == LT || cmp == LE) {
719     ++eq[lhsPos];
720     --eq[rhsPos];
721   } else if (cmp == GT || cmp == GE) {
722     --eq[lhsPos];
723     ++eq[rhsPos];
724   } else {
725     llvm_unreachable("unsupported comparison operator");
726   }
727   if (cmp == LE || cmp == GE)
728     eq[cstr.getNumCols() - 1] -= 1;
729 
730   // Add inequality to the constraint set and check if it made the constraint
731   // set empty.
732   int64_t ineqPos = cstr.getNumInequalities();
733   cstr.addInequality(eq);
734   bool isEmpty = cstr.isEmpty();
735   cstr.removeInequality(ineqPos);
736   return isEmpty;
737 }
738 
739 bool ValueBoundsConstraintSet::populateAndCompare(const Variable &lhs,
740                                                   ComparisonOperator cmp,
741                                                   const Variable &rhs) {
742   int64_t lhsPos = populateConstraints(lhs.map, lhs.mapOperands);
743   int64_t rhsPos = populateConstraints(rhs.map, rhs.mapOperands);
744   return comparePos(lhsPos, cmp, rhsPos);
745 }
746 
747 bool ValueBoundsConstraintSet::compare(const Variable &lhs,
748                                        ComparisonOperator cmp,
749                                        const Variable &rhs) {
750   int64_t lhsPos = -1, rhsPos = -1;
751   auto stopCondition = [&](Value v, std::optional<int64_t> dim,
752                            ValueBoundsConstraintSet &cstr) {
753     // Keep processing as long as lhs/rhs were not processed.
754     if (size_t(lhsPos) >= cstr.positionToValueDim.size() ||
755         size_t(rhsPos) >= cstr.positionToValueDim.size())
756       return false;
757     // Keep processing as long as the relation cannot be proven.
758     return cstr.comparePos(lhsPos, cmp, rhsPos);
759   };
760   ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
761   lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
762   rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
763   return cstr.comparePos(lhsPos, cmp, rhsPos);
764 }
765 
766 FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1,
767                                                    const Variable &var2) {
768   if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2))
769     return true;
770   if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::LT, var2) ||
771       ValueBoundsConstraintSet::compare(var1, ComparisonOperator::GT, var2))
772     return false;
773   return failure();
774 }
775 
776 FailureOr<bool>
777 ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
778                                                HyperrectangularSlice slice1,
779                                                HyperrectangularSlice slice2) {
780   assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() &&
781          "expected slices of same rank");
782   assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() &&
783          "expected slices of same rank");
784   assert(slice1.getMixedStrides().size() == slice2.getMixedStrides().size() &&
785          "expected slices of same rank");
786 
787   Builder b(ctx);
788   bool foundUnknownBound = false;
789   for (int64_t i = 0, e = slice1.getMixedOffsets().size(); i < e; ++i) {
790     AffineMap map =
791         AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4,
792                        b.getAffineSymbolExpr(0) +
793                            b.getAffineSymbolExpr(1) * b.getAffineSymbolExpr(2) -
794                            b.getAffineSymbolExpr(3));
795     {
796       // Case 1: Slices are guaranteed to be non-overlapping if
797       // offset1 + size1 * stride1 <= offset2 (for at least one dimension).
798       SmallVector<OpFoldResult> ofrOperands;
799       ofrOperands.push_back(slice1.getMixedOffsets()[i]);
800       ofrOperands.push_back(slice1.getMixedSizes()[i]);
801       ofrOperands.push_back(slice1.getMixedStrides()[i]);
802       ofrOperands.push_back(slice2.getMixedOffsets()[i]);
803       SmallVector<Value> valueOperands;
804       AffineMap foldedMap =
805           foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
806       FailureOr<int64_t> constBound = computeConstantBound(
807           presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
808       foundUnknownBound |= failed(constBound);
809       if (succeeded(constBound) && *constBound <= 0)
810         return false;
811     }
812     {
813       // Case 2: Slices are guaranteed to be non-overlapping if
814       // offset2 + size2 * stride2 <= offset1 (for at least one dimension).
815       SmallVector<OpFoldResult> ofrOperands;
816       ofrOperands.push_back(slice2.getMixedOffsets()[i]);
817       ofrOperands.push_back(slice2.getMixedSizes()[i]);
818       ofrOperands.push_back(slice2.getMixedStrides()[i]);
819       ofrOperands.push_back(slice1.getMixedOffsets()[i]);
820       SmallVector<Value> valueOperands;
821       AffineMap foldedMap =
822           foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
823       FailureOr<int64_t> constBound = computeConstantBound(
824           presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
825       foundUnknownBound |= failed(constBound);
826       if (succeeded(constBound) && *constBound <= 0)
827         return false;
828     }
829   }
830 
831   // If at least one bound could not be computed, we cannot be certain that the
832   // slices are really overlapping.
833   if (foundUnknownBound)
834     return failure();
835 
836   // All bounds could be computed and none of the above cases applied.
837   // Therefore, the slices are guaranteed to overlap.
838   return true;
839 }
840 
841 FailureOr<bool>
842 ValueBoundsConstraintSet::areEquivalentSlices(MLIRContext *ctx,
843                                               HyperrectangularSlice slice1,
844                                               HyperrectangularSlice slice2) {
845   assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() &&
846          "expected slices of same rank");
847   assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() &&
848          "expected slices of same rank");
849   assert(slice1.getMixedStrides().size() == slice2.getMixedStrides().size() &&
850          "expected slices of same rank");
851 
852   // The two slices are equivalent if all of their offsets, sizes and strides
853   // are equal. If equality cannot be determined for at least one of those
854   // values, equivalence cannot be determined and this function returns
855   // "failure".
856   for (auto [offset1, offset2] :
857        llvm::zip_equal(slice1.getMixedOffsets(), slice2.getMixedOffsets())) {
858     FailureOr<bool> equal = areEqual(offset1, offset2);
859     if (failed(equal))
860       return failure();
861     if (!equal.value())
862       return false;
863   }
864   for (auto [size1, size2] :
865        llvm::zip_equal(slice1.getMixedSizes(), slice2.getMixedSizes())) {
866     FailureOr<bool> equal = areEqual(size1, size2);
867     if (failed(equal))
868       return failure();
869     if (!equal.value())
870       return false;
871   }
872   for (auto [stride1, stride2] :
873        llvm::zip_equal(slice1.getMixedStrides(), slice2.getMixedStrides())) {
874     FailureOr<bool> equal = areEqual(stride1, stride2);
875     if (failed(equal))
876       return failure();
877     if (!equal.value())
878       return false;
879   }
880   return true;
881 }
882 
883 void ValueBoundsConstraintSet::dump() const {
884   llvm::errs() << "==========\nColumns:\n";
885   llvm::errs() << "(column\tdim\tvalue)\n";
886   for (auto [index, valueDim] : llvm::enumerate(positionToValueDim)) {
887     llvm::errs() << " " << index << "\t";
888     if (valueDim) {
889       if (valueDim->second == kIndexValue) {
890         llvm::errs() << "n/a\t";
891       } else {
892         llvm::errs() << valueDim->second << "\t";
893       }
894       llvm::errs() << getOwnerOfValue(valueDim->first)->getName() << " ";
895       if (OpResult result = dyn_cast<OpResult>(valueDim->first)) {
896         llvm::errs() << "(result " << result.getResultNumber() << ")";
897       } else {
898         llvm::errs() << "(bbarg "
899                      << cast<BlockArgument>(valueDim->first).getArgNumber()
900                      << ")";
901       }
902       llvm::errs() << "\n";
903     } else {
904       llvm::errs() << "n/a\tn/a\n";
905     }
906   }
907   llvm::errs() << "\nConstraint set:\n";
908   cstr.dump();
909   llvm::errs() << "==========\n";
910 }
911 
912 ValueBoundsConstraintSet::BoundBuilder &
913 ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
914   assert(!this->dim.has_value() && "dim was already set");
915   this->dim = dim;
916 #ifndef NDEBUG
917   assertValidValueDim(value, this->dim);
918 #endif // NDEBUG
919   return *this;
920 }
921 
922 void ValueBoundsConstraintSet::BoundBuilder::operator<(AffineExpr expr) {
923 #ifndef NDEBUG
924   assertValidValueDim(value, this->dim);
925 #endif // NDEBUG
926   cstr.addBound(BoundType::UB, cstr.getPos(value, this->dim), expr);
927 }
928 
929 void ValueBoundsConstraintSet::BoundBuilder::operator<=(AffineExpr expr) {
930   operator<(expr + 1);
931 }
932 
933 void ValueBoundsConstraintSet::BoundBuilder::operator>(AffineExpr expr) {
934   operator>=(expr + 1);
935 }
936 
937 void ValueBoundsConstraintSet::BoundBuilder::operator>=(AffineExpr expr) {
938 #ifndef NDEBUG
939   assertValidValueDim(value, this->dim);
940 #endif // NDEBUG
941   cstr.addBound(BoundType::LB, cstr.getPos(value, this->dim), expr);
942 }
943 
944 void ValueBoundsConstraintSet::BoundBuilder::operator==(AffineExpr expr) {
945 #ifndef NDEBUG
946   assertValidValueDim(value, this->dim);
947 #endif // NDEBUG
948   cstr.addBound(BoundType::EQ, cstr.getPos(value, this->dim), expr);
949 }
950 
951 void ValueBoundsConstraintSet::BoundBuilder::operator<(OpFoldResult ofr) {
952   operator<(cstr.getExpr(ofr));
953 }
954 
955 void ValueBoundsConstraintSet::BoundBuilder::operator<=(OpFoldResult ofr) {
956   operator<=(cstr.getExpr(ofr));
957 }
958 
959 void ValueBoundsConstraintSet::BoundBuilder::operator>(OpFoldResult ofr) {
960   operator>(cstr.getExpr(ofr));
961 }
962 
963 void ValueBoundsConstraintSet::BoundBuilder::operator>=(OpFoldResult ofr) {
964   operator>=(cstr.getExpr(ofr));
965 }
966 
967 void ValueBoundsConstraintSet::BoundBuilder::operator==(OpFoldResult ofr) {
968   operator==(cstr.getExpr(ofr));
969 }
970 
971 void ValueBoundsConstraintSet::BoundBuilder::operator<(int64_t i) {
972   operator<(cstr.getExpr(i));
973 }
974 
975 void ValueBoundsConstraintSet::BoundBuilder::operator<=(int64_t i) {
976   operator<=(cstr.getExpr(i));
977 }
978 
979 void ValueBoundsConstraintSet::BoundBuilder::operator>(int64_t i) {
980   operator>(cstr.getExpr(i));
981 }
982 
983 void ValueBoundsConstraintSet::BoundBuilder::operator>=(int64_t i) {
984   operator>=(cstr.getExpr(i));
985 }
986 
987 void ValueBoundsConstraintSet::BoundBuilder::operator==(int64_t i) {
988   operator==(cstr.getExpr(i));
989 }
990