xref: /llvm-project/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp (revision 40dd3aa91d3f73184e34e45e597b84bec059c572)
1 //===- ReifyValueBounds.cpp --- Reify value bounds with arith ops -------*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/Transforms/Transforms.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
15 
16 using namespace mlir;
17 using namespace mlir::arith;
18 
19 /// Build Arith IR for the given affine map and its operands.
buildArithValue(OpBuilder & b,Location loc,AffineMap map,ValueRange operands)20 static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map,
21                              ValueRange operands) {
22   assert(map.getNumResults() == 1 && "multiple results not supported yet");
23   std::function<Value(AffineExpr)> buildExpr = [&](AffineExpr e) -> Value {
24     switch (e.getKind()) {
25     case AffineExprKind::Constant:
26       return b.create<ConstantIndexOp>(loc,
27                                        cast<AffineConstantExpr>(e).getValue());
28     case AffineExprKind::DimId:
29       return operands[cast<AffineDimExpr>(e).getPosition()];
30     case AffineExprKind::SymbolId:
31       return operands[cast<AffineSymbolExpr>(e).getPosition() +
32                       map.getNumDims()];
33     case AffineExprKind::Add: {
34       auto binaryExpr = cast<AffineBinaryOpExpr>(e);
35       return b.create<AddIOp>(loc, buildExpr(binaryExpr.getLHS()),
36                               buildExpr(binaryExpr.getRHS()));
37     }
38     case AffineExprKind::Mul: {
39       auto binaryExpr = cast<AffineBinaryOpExpr>(e);
40       return b.create<MulIOp>(loc, buildExpr(binaryExpr.getLHS()),
41                               buildExpr(binaryExpr.getRHS()));
42     }
43     case AffineExprKind::FloorDiv: {
44       auto binaryExpr = cast<AffineBinaryOpExpr>(e);
45       return b.create<DivSIOp>(loc, buildExpr(binaryExpr.getLHS()),
46                                buildExpr(binaryExpr.getRHS()));
47     }
48     case AffineExprKind::CeilDiv: {
49       auto binaryExpr = cast<AffineBinaryOpExpr>(e);
50       return b.create<CeilDivSIOp>(loc, buildExpr(binaryExpr.getLHS()),
51                                    buildExpr(binaryExpr.getRHS()));
52     }
53     case AffineExprKind::Mod: {
54       auto binaryExpr = cast<AffineBinaryOpExpr>(e);
55       return b.create<RemSIOp>(loc, buildExpr(binaryExpr.getLHS()),
56                                buildExpr(binaryExpr.getRHS()));
57     }
58     }
59     llvm_unreachable("unsupported AffineExpr kind");
60   };
61   return buildExpr(map.getResult(0));
62 }
63 
reifyValueBound(OpBuilder & b,Location loc,presburger::BoundType type,const ValueBoundsConstraintSet::Variable & var,ValueBoundsConstraintSet::StopConditionFn stopCondition,bool closedUB)64 FailureOr<OpFoldResult> mlir::arith::reifyValueBound(
65     OpBuilder &b, Location loc, presburger::BoundType type,
66     const ValueBoundsConstraintSet::Variable &var,
67     ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
68   // Compute bound.
69   AffineMap boundMap;
70   ValueDimList mapOperands;
71   if (failed(ValueBoundsConstraintSet::computeBound(
72           boundMap, mapOperands, type, var, stopCondition, closedUB)))
73     return failure();
74 
75   // Materialize tensor.dim/memref.dim ops.
76   SmallVector<Value> operands;
77   for (auto valueDim : mapOperands) {
78     Value value = valueDim.first;
79     std::optional<int64_t> dim = valueDim.second;
80 
81     if (!dim.has_value()) {
82       // This is an index-typed value.
83       assert(value.getType().isIndex() && "expected index type");
84       operands.push_back(value);
85       continue;
86     }
87 
88     assert(cast<ShapedType>(value.getType()).isDynamicDim(*dim) &&
89            "expected dynamic dim");
90     if (isa<RankedTensorType>(value.getType())) {
91       // A tensor dimension is used: generate a tensor.dim.
92       operands.push_back(b.create<tensor::DimOp>(loc, value, *dim));
93     } else if (isa<MemRefType>(value.getType())) {
94       // A memref dimension is used: generate a memref.dim.
95       operands.push_back(b.create<memref::DimOp>(loc, value, *dim));
96     } else {
97       llvm_unreachable("cannot generate DimOp for unsupported shaped type");
98     }
99   }
100 
101   // Check for special cases where no arith ops are needed.
102   if (boundMap.isSingleConstant()) {
103     // Bound is a constant: return an IntegerAttr.
104     return static_cast<OpFoldResult>(
105         b.getIndexAttr(boundMap.getSingleConstantResult()));
106   }
107   // No arith ops are needed if the bound is a single SSA value.
108   if (auto expr = dyn_cast<AffineDimExpr>(boundMap.getResult(0)))
109     return static_cast<OpFoldResult>(operands[expr.getPosition()]);
110   if (auto expr = dyn_cast<AffineSymbolExpr>(boundMap.getResult(0)))
111     return static_cast<OpFoldResult>(
112         operands[expr.getPosition() + boundMap.getNumDims()]);
113   // General case: build Arith ops.
114   return static_cast<OpFoldResult>(buildArithValue(b, loc, boundMap, operands));
115 }
116 
reifyShapedValueDimBound(OpBuilder & b,Location loc,presburger::BoundType type,Value value,int64_t dim,ValueBoundsConstraintSet::StopConditionFn stopCondition,bool closedUB)117 FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
118     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
119     int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
120     bool closedUB) {
121   auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
122                              ValueBoundsConstraintSet &cstr) {
123     // We are trying to reify a bound for `value` in terms of the owning op's
124     // operands. Construct a stop condition that evaluates to "true" for any SSA
125     // value expect for `value`. I.e., the bound will be computed in terms of
126     // any SSA values expect for `value`. The first such values are operands of
127     // the owner of `value`.
128     return v != value;
129   };
130   return reifyValueBound(b, loc, type, {value, dim},
131                          stopCondition ? stopCondition : reifyToOperands,
132                          closedUB);
133 }
134 
reifyIndexValueBound(OpBuilder & b,Location loc,presburger::BoundType type,Value value,ValueBoundsConstraintSet::StopConditionFn stopCondition,bool closedUB)135 FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
136     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
137     ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
138   auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
139                              ValueBoundsConstraintSet &cstr) {
140     return v != value;
141   };
142   return reifyValueBound(b, loc, type, value,
143                          stopCondition ? stopCondition : reifyToOperands,
144                          closedUB);
145 }
146