1 //===- ReifyValueBounds.cpp --- Reify value bounds with arith ops -------*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Arith/Transforms/Transforms.h"
10
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
15
16 using namespace mlir;
17 using namespace mlir::arith;
18
19 /// Build Arith IR for the given affine map and its operands.
buildArithValue(OpBuilder & b,Location loc,AffineMap map,ValueRange operands)20 static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map,
21 ValueRange operands) {
22 assert(map.getNumResults() == 1 && "multiple results not supported yet");
23 std::function<Value(AffineExpr)> buildExpr = [&](AffineExpr e) -> Value {
24 switch (e.getKind()) {
25 case AffineExprKind::Constant:
26 return b.create<ConstantIndexOp>(loc,
27 cast<AffineConstantExpr>(e).getValue());
28 case AffineExprKind::DimId:
29 return operands[cast<AffineDimExpr>(e).getPosition()];
30 case AffineExprKind::SymbolId:
31 return operands[cast<AffineSymbolExpr>(e).getPosition() +
32 map.getNumDims()];
33 case AffineExprKind::Add: {
34 auto binaryExpr = cast<AffineBinaryOpExpr>(e);
35 return b.create<AddIOp>(loc, buildExpr(binaryExpr.getLHS()),
36 buildExpr(binaryExpr.getRHS()));
37 }
38 case AffineExprKind::Mul: {
39 auto binaryExpr = cast<AffineBinaryOpExpr>(e);
40 return b.create<MulIOp>(loc, buildExpr(binaryExpr.getLHS()),
41 buildExpr(binaryExpr.getRHS()));
42 }
43 case AffineExprKind::FloorDiv: {
44 auto binaryExpr = cast<AffineBinaryOpExpr>(e);
45 return b.create<DivSIOp>(loc, buildExpr(binaryExpr.getLHS()),
46 buildExpr(binaryExpr.getRHS()));
47 }
48 case AffineExprKind::CeilDiv: {
49 auto binaryExpr = cast<AffineBinaryOpExpr>(e);
50 return b.create<CeilDivSIOp>(loc, buildExpr(binaryExpr.getLHS()),
51 buildExpr(binaryExpr.getRHS()));
52 }
53 case AffineExprKind::Mod: {
54 auto binaryExpr = cast<AffineBinaryOpExpr>(e);
55 return b.create<RemSIOp>(loc, buildExpr(binaryExpr.getLHS()),
56 buildExpr(binaryExpr.getRHS()));
57 }
58 }
59 llvm_unreachable("unsupported AffineExpr kind");
60 };
61 return buildExpr(map.getResult(0));
62 }
63
reifyValueBound(OpBuilder & b,Location loc,presburger::BoundType type,const ValueBoundsConstraintSet::Variable & var,ValueBoundsConstraintSet::StopConditionFn stopCondition,bool closedUB)64 FailureOr<OpFoldResult> mlir::arith::reifyValueBound(
65 OpBuilder &b, Location loc, presburger::BoundType type,
66 const ValueBoundsConstraintSet::Variable &var,
67 ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
68 // Compute bound.
69 AffineMap boundMap;
70 ValueDimList mapOperands;
71 if (failed(ValueBoundsConstraintSet::computeBound(
72 boundMap, mapOperands, type, var, stopCondition, closedUB)))
73 return failure();
74
75 // Materialize tensor.dim/memref.dim ops.
76 SmallVector<Value> operands;
77 for (auto valueDim : mapOperands) {
78 Value value = valueDim.first;
79 std::optional<int64_t> dim = valueDim.second;
80
81 if (!dim.has_value()) {
82 // This is an index-typed value.
83 assert(value.getType().isIndex() && "expected index type");
84 operands.push_back(value);
85 continue;
86 }
87
88 assert(cast<ShapedType>(value.getType()).isDynamicDim(*dim) &&
89 "expected dynamic dim");
90 if (isa<RankedTensorType>(value.getType())) {
91 // A tensor dimension is used: generate a tensor.dim.
92 operands.push_back(b.create<tensor::DimOp>(loc, value, *dim));
93 } else if (isa<MemRefType>(value.getType())) {
94 // A memref dimension is used: generate a memref.dim.
95 operands.push_back(b.create<memref::DimOp>(loc, value, *dim));
96 } else {
97 llvm_unreachable("cannot generate DimOp for unsupported shaped type");
98 }
99 }
100
101 // Check for special cases where no arith ops are needed.
102 if (boundMap.isSingleConstant()) {
103 // Bound is a constant: return an IntegerAttr.
104 return static_cast<OpFoldResult>(
105 b.getIndexAttr(boundMap.getSingleConstantResult()));
106 }
107 // No arith ops are needed if the bound is a single SSA value.
108 if (auto expr = dyn_cast<AffineDimExpr>(boundMap.getResult(0)))
109 return static_cast<OpFoldResult>(operands[expr.getPosition()]);
110 if (auto expr = dyn_cast<AffineSymbolExpr>(boundMap.getResult(0)))
111 return static_cast<OpFoldResult>(
112 operands[expr.getPosition() + boundMap.getNumDims()]);
113 // General case: build Arith ops.
114 return static_cast<OpFoldResult>(buildArithValue(b, loc, boundMap, operands));
115 }
116
reifyShapedValueDimBound(OpBuilder & b,Location loc,presburger::BoundType type,Value value,int64_t dim,ValueBoundsConstraintSet::StopConditionFn stopCondition,bool closedUB)117 FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
118 OpBuilder &b, Location loc, presburger::BoundType type, Value value,
119 int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
120 bool closedUB) {
121 auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
122 ValueBoundsConstraintSet &cstr) {
123 // We are trying to reify a bound for `value` in terms of the owning op's
124 // operands. Construct a stop condition that evaluates to "true" for any SSA
125 // value expect for `value`. I.e., the bound will be computed in terms of
126 // any SSA values expect for `value`. The first such values are operands of
127 // the owner of `value`.
128 return v != value;
129 };
130 return reifyValueBound(b, loc, type, {value, dim},
131 stopCondition ? stopCondition : reifyToOperands,
132 closedUB);
133 }
134
reifyIndexValueBound(OpBuilder & b,Location loc,presburger::BoundType type,Value value,ValueBoundsConstraintSet::StopConditionFn stopCondition,bool closedUB)135 FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
136 OpBuilder &b, Location loc, presburger::BoundType type, Value value,
137 ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
138 auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
139 ValueBoundsConstraintSet &cstr) {
140 return v != value;
141 };
142 return reifyValueBound(b, loc, type, value,
143 stopCondition ? stopCondition : reifyToOperands,
144 closedUB);
145 }
146