xref: /llvm-project/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp (revision 40dd3aa91d3f73184e34e45e597b84bec059c572)
1 //===- ReifyValueBounds.cpp --- Reify value bounds with affine ops ------*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
15 
16 using namespace mlir;
17 using namespace mlir::affine;
18 
reifyValueBound(OpBuilder & b,Location loc,presburger::BoundType type,const ValueBoundsConstraintSet::Variable & var,ValueBoundsConstraintSet::StopConditionFn stopCondition,bool closedUB)19 FailureOr<OpFoldResult> mlir::affine::reifyValueBound(
20     OpBuilder &b, Location loc, presburger::BoundType type,
21     const ValueBoundsConstraintSet::Variable &var,
22     ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
23   // Compute bound.
24   AffineMap boundMap;
25   ValueDimList mapOperands;
26   if (failed(ValueBoundsConstraintSet::computeBound(
27           boundMap, mapOperands, type, var, stopCondition, closedUB)))
28     return failure();
29 
30   // Reify bound.
31   return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
32 }
33 
materializeComputedBound(OpBuilder & b,Location loc,AffineMap boundMap,ArrayRef<std::pair<Value,std::optional<int64_t>>> mapOperands)34 OpFoldResult affine::materializeComputedBound(
35     OpBuilder &b, Location loc, AffineMap boundMap,
36     ArrayRef<std::pair<Value, std::optional<int64_t>>> mapOperands) {
37   // Materialize tensor.dim/memref.dim ops.
38   SmallVector<Value> operands;
39   for (auto valueDim : mapOperands) {
40     Value value = valueDim.first;
41     std::optional<int64_t> dim = valueDim.second;
42 
43     if (!dim.has_value()) {
44       // This is an index-typed value.
45       assert(value.getType().isIndex() && "expected index type");
46       operands.push_back(value);
47       continue;
48     }
49 
50     assert(cast<ShapedType>(value.getType()).isDynamicDim(*dim) &&
51            "expected dynamic dim");
52     if (isa<RankedTensorType>(value.getType())) {
53       // A tensor dimension is used: generate a tensor.dim.
54       operands.push_back(b.create<tensor::DimOp>(loc, value, *dim));
55     } else if (isa<MemRefType>(value.getType())) {
56       // A memref dimension is used: generate a memref.dim.
57       operands.push_back(b.create<memref::DimOp>(loc, value, *dim));
58     } else {
59       llvm_unreachable("cannot generate DimOp for unsupported shaped type");
60     }
61   }
62 
63   // Simplify and return bound.
64   affine::canonicalizeMapAndOperands(&boundMap, &operands);
65   // Check for special cases where no affine.apply op is needed.
66   if (boundMap.isSingleConstant()) {
67     // Bound is a constant: return an IntegerAttr.
68     return static_cast<OpFoldResult>(
69         b.getIndexAttr(boundMap.getSingleConstantResult()));
70   }
71   // No affine.apply op is needed if the bound is a single SSA value.
72   if (auto expr = dyn_cast<AffineDimExpr>(boundMap.getResult(0)))
73     return static_cast<OpFoldResult>(operands[expr.getPosition()]);
74   if (auto expr = dyn_cast<AffineSymbolExpr>(boundMap.getResult(0)))
75     return static_cast<OpFoldResult>(
76         operands[expr.getPosition() + boundMap.getNumDims()]);
77   // General case: build affine.apply op.
78   return static_cast<OpFoldResult>(
79       b.create<affine::AffineApplyOp>(loc, boundMap, operands).getResult());
80 }
81 
reifyShapedValueDimBound(OpBuilder & b,Location loc,presburger::BoundType type,Value value,int64_t dim,ValueBoundsConstraintSet::StopConditionFn stopCondition,bool closedUB)82 FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
83     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
84     int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
85     bool closedUB) {
86   auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
87                              ValueBoundsConstraintSet &cstr) {
88     // We are trying to reify a bound for `value` in terms of the owning op's
89     // operands. Construct a stop condition that evaluates to "true" for any SSA
90     // value except for `value`. I.e., the bound will be computed in terms of
91     // any SSA values except for `value`. The first such values are operands of
92     // the owner of `value`.
93     return v != value;
94   };
95   return reifyValueBound(b, loc, type, {value, dim},
96                          stopCondition ? stopCondition : reifyToOperands,
97                          closedUB);
98 }
99 
reifyIndexValueBound(OpBuilder & b,Location loc,presburger::BoundType type,Value value,ValueBoundsConstraintSet::StopConditionFn stopCondition,bool closedUB)100 FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
101     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
102     ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
103   auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
104                              ValueBoundsConstraintSet &cstr) {
105     return v != value;
106   };
107   return reifyValueBound(b, loc, type, value,
108                          stopCondition ? stopCondition : reifyToOperands,
109                          closedUB);
110 }
111