xref: /llvm-project/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp (revision c6f67b8e39a907fb96b715cae3ee90e4c1b248aa)
1 //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
13 
14 using namespace mlir;
15 using namespace mlir::affine;
16 
17 namespace mlir {
18 namespace {
19 
20 struct AffineApplyOpInterface
21     : public ValueBoundsOpInterface::ExternalModel<AffineApplyOpInterface,
22                                                    AffineApplyOp> {
23   void populateBoundsForIndexValue(Operation *op, Value value,
24                                    ValueBoundsConstraintSet &cstr) const {
25     auto applyOp = cast<AffineApplyOp>(op);
26     assert(value == applyOp.getResult() && "invalid value");
27     assert(applyOp.getAffineMap().getNumResults() == 1 &&
28            "expected single result");
29 
30     // Fully compose this affine.apply with other ops because the folding logic
31     // can see opportunities for simplifying the affine map that
32     // `FlatLinearConstraints` can currently not see.
33     AffineMap map = applyOp.getAffineMap();
34     SmallVector<Value> operands = llvm::to_vector(applyOp.getOperands());
35     fullyComposeAffineMapAndOperands(&map, &operands);
36 
37     // Align affine map result with dims/symbols in the constraint set.
38     AffineExpr expr = map.getResult(0);
39     SmallVector<AffineExpr> dimReplacements, symReplacements;
40     for (int64_t i = 0, e = map.getNumDims(); i < e; ++i)
41       dimReplacements.push_back(cstr.getExpr(operands[i]));
42     for (int64_t i = map.getNumDims(),
43                  e = map.getNumDims() + map.getNumSymbols();
44          i < e; ++i)
45       symReplacements.push_back(cstr.getExpr(operands[i]));
46     AffineExpr bound =
47         expr.replaceDimsAndSymbols(dimReplacements, symReplacements);
48     cstr.bound(value) == bound;
49   }
50 };
51 
52 struct AffineMinOpInterface
53     : public ValueBoundsOpInterface::ExternalModel<AffineMinOpInterface,
54                                                    AffineMinOp> {
55   void populateBoundsForIndexValue(Operation *op, Value value,
56                                    ValueBoundsConstraintSet &cstr) const {
57     auto minOp = cast<AffineMinOp>(op);
58     assert(value == minOp.getResult() && "invalid value");
59 
60     // Align affine map results with dims/symbols in the constraint set.
61     for (AffineExpr expr : minOp.getAffineMap().getResults()) {
62       SmallVector<AffineExpr> dimReplacements = llvm::to_vector(llvm::map_range(
63           minOp.getDimOperands(), [&](Value v) { return cstr.getExpr(v); }));
64       SmallVector<AffineExpr> symReplacements = llvm::to_vector(llvm::map_range(
65           minOp.getSymbolOperands(), [&](Value v) { return cstr.getExpr(v); }));
66       AffineExpr bound =
67           expr.replaceDimsAndSymbols(dimReplacements, symReplacements);
68       cstr.bound(value) <= bound;
69     }
70   };
71 };
72 
73 struct AffineMaxOpInterface
74     : public ValueBoundsOpInterface::ExternalModel<AffineMaxOpInterface,
75                                                    AffineMaxOp> {
76   void populateBoundsForIndexValue(Operation *op, Value value,
77                                    ValueBoundsConstraintSet &cstr) const {
78     auto maxOp = cast<AffineMaxOp>(op);
79     assert(value == maxOp.getResult() && "invalid value");
80 
81     // Align affine map results with dims/symbols in the constraint set.
82     for (AffineExpr expr : maxOp.getAffineMap().getResults()) {
83       SmallVector<AffineExpr> dimReplacements = llvm::to_vector(llvm::map_range(
84           maxOp.getDimOperands(), [&](Value v) { return cstr.getExpr(v); }));
85       SmallVector<AffineExpr> symReplacements = llvm::to_vector(llvm::map_range(
86           maxOp.getSymbolOperands(), [&](Value v) { return cstr.getExpr(v); }));
87       AffineExpr bound =
88           expr.replaceDimsAndSymbols(dimReplacements, symReplacements);
89       cstr.bound(value) >= bound;
90     }
91   };
92 };
93 
94 struct AffineDelinearizeIndexOpInterface
95     : public ValueBoundsOpInterface::ExternalModel<
96           AffineDelinearizeIndexOpInterface, AffineDelinearizeIndexOp> {
97   void populateBoundsForIndexValue(Operation *rawOp, Value value,
98                                    ValueBoundsConstraintSet &cstr) const {
99     auto op = cast<AffineDelinearizeIndexOp>(rawOp);
100     auto result = cast<OpResult>(value);
101     assert(result.getOwner() == rawOp &&
102            "bounded value isn't a result of this delinearize_index");
103     unsigned resIdx = result.getResultNumber();
104 
105     AffineExpr linearIdx = cstr.getExpr(op.getLinearIndex());
106 
107     SmallVector<OpFoldResult> basis = op.getPaddedBasis();
108     AffineExpr divisor = cstr.getExpr(1);
109     for (OpFoldResult basisElem : llvm::drop_begin(basis, resIdx + 1))
110       divisor = divisor * cstr.getExpr(basisElem);
111 
112     if (resIdx == 0) {
113       cstr.bound(value) == linearIdx.floorDiv(divisor);
114       if (!basis.front().isNull())
115         cstr.bound(value) < cstr.getExpr(basis.front());
116       return;
117     }
118     AffineExpr thisBasis = cstr.getExpr(basis[resIdx]);
119     cstr.bound(value) == (linearIdx % (thisBasis * divisor)).floorDiv(divisor);
120   }
121 };
122 
123 struct AffineLinearizeIndexOpInterface
124     : public ValueBoundsOpInterface::ExternalModel<
125           AffineLinearizeIndexOpInterface, AffineLinearizeIndexOp> {
126   void populateBoundsForIndexValue(Operation *rawOp, Value value,
127                                    ValueBoundsConstraintSet &cstr) const {
128     auto op = cast<AffineLinearizeIndexOp>(rawOp);
129     assert(value == op.getResult() &&
130            "value isn't the result of this linearize");
131 
132     AffineExpr bound = cstr.getExpr(0);
133     AffineExpr stride = cstr.getExpr(1);
134     SmallVector<OpFoldResult> basis = op.getPaddedBasis();
135     OperandRange multiIndex = op.getMultiIndex();
136     unsigned numArgs = multiIndex.size();
137     for (auto [revArgNum, length] : llvm::enumerate(llvm::reverse(basis))) {
138       unsigned argNum = numArgs - (revArgNum + 1);
139       if (argNum == 0)
140         break;
141       OpFoldResult indexAsFoldRes = getAsOpFoldResult(multiIndex[argNum]);
142       bound = bound + cstr.getExpr(indexAsFoldRes) * stride;
143       stride = stride * cstr.getExpr(length);
144     }
145     bound = bound + cstr.getExpr(op.getMultiIndex().front()) * stride;
146     cstr.bound(value) == bound;
147     if (op.getDisjoint() && !basis.front().isNull()) {
148       cstr.bound(value) < stride *cstr.getExpr(basis.front());
149     }
150   }
151 };
152 } // namespace
153 } // namespace mlir
154 
155 void mlir::affine::registerValueBoundsOpInterfaceExternalModels(
156     DialectRegistry &registry) {
157   registry.addExtension(+[](MLIRContext *ctx, AffineDialect *dialect) {
158     AffineApplyOp::attachInterface<AffineApplyOpInterface>(*ctx);
159     AffineMaxOp::attachInterface<AffineMaxOpInterface>(*ctx);
160     AffineMinOp::attachInterface<AffineMinOpInterface>(*ctx);
161     AffineDelinearizeIndexOp::attachInterface<
162         AffineDelinearizeIndexOpInterface>(*ctx);
163     AffineLinearizeIndexOp::attachInterface<AffineLinearizeIndexOpInterface>(
164         *ctx);
165   });
166 }
167 
168 FailureOr<int64_t>
169 mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
170   assert(value1.getType().isIndex() && "expected index type");
171   assert(value2.getType().isIndex() && "expected index type");
172 
173   // Subtract the two values/dimensions from each other. If the result is 0,
174   // both are equal.
175   Builder b(value1.getContext());
176   AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
177                                  b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
178   // Fully compose the affine map with other ops because the folding logic
179   // can see opportunities for simplifying the affine map that
180   // `FlatLinearConstraints` can currently not see.
181   SmallVector<Value> mapOperands;
182   mapOperands.push_back(value1);
183   mapOperands.push_back(value2);
184   affine::fullyComposeAffineMapAndOperands(&map, &mapOperands);
185   return ValueBoundsConstraintSet::computeConstantBound(
186       presburger::BoundType::EQ,
187       ValueBoundsConstraintSet::Variable(map, mapOperands));
188 }
189