xref: /llvm-project/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp (revision f8b27949a8c4fa8d8e15f9858e2ed38d7267f7dd)
1 //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/SCF/IR/SCF.h"
12 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
13 
14 using namespace mlir;
15 
16 namespace mlir {
17 namespace scf {
18 namespace {
19 
20 struct ForOpInterface
21     : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
22 
23   /// Populate bounds of values/dimensions for iter_args/OpResults. If the
24   /// value/dimension size does not change in an iteration, we can deduce that
25   /// it the same as the initial value/dimension.
26   ///
27   /// Example 1:
28   /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
29   ///   ...
30   ///   %1 = tensor.insert %f into %arg0[...] : tensor<?xf32>
31   ///   scf.yield %1 : tensor<?xf32>
32   /// }
33   /// --> bound(%0)[0] == bound(%t)[0]
34   /// --> bound(%arg0)[0] == bound(%t)[0]
35   ///
36   /// Example 2:
37   /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
38   ///   %sz = tensor.dim %arg0 : tensor<?xf32>
39   ///   %incr = arith.addi %sz, %c1 : index
40   ///   %1 = tensor.empty(%incr) : tensor<?xf32>
41   ///   scf.yield %1 : tensor<?xf32>
42   /// }
43   /// --> The yielded tensor dimension size changes with each iteration. Such
44   ///     loops are not supported and no constraints are added.
45   static void populateIterArgBounds(scf::ForOp forOp, Value value,
46                                     std::optional<int64_t> dim,
47                                     ValueBoundsConstraintSet &cstr) {
48     // `value` is an iter_arg or an OpResult.
49     int64_t iterArgIdx;
50     if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
51       iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars();
52     } else {
53       iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
54     }
55 
56     Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
57                              .getOperand(iterArgIdx);
58     Value iterArg = forOp.getRegionIterArg(iterArgIdx);
59     Value initArg = forOp.getInitArgs()[iterArgIdx];
60 
61     // An EQ constraint can be added if the yielded value (dimension size)
62     // equals the corresponding block argument (dimension size).
63     if (cstr.populateAndCompare(
64             /*lhs=*/{yieldedValue, dim},
65             ValueBoundsConstraintSet::ComparisonOperator::EQ,
66             /*rhs=*/{iterArg, dim})) {
67       if (dim.has_value()) {
68         cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
69       } else {
70         cstr.bound(value) == cstr.getExpr(initArg);
71       }
72     }
73 
74     if (dim.has_value() || isa<BlockArgument>(value))
75       return;
76 
77     // `value` is result of `forOp`, we can prove that:
78     // %result == %init_arg + trip_count * (%yielded_value - %iter_arg).
79     // Where trip_count is (ub - lb) / step.
80     AffineExpr lbExpr = cstr.getExpr(forOp.getLowerBound());
81     AffineExpr ubExpr = cstr.getExpr(forOp.getUpperBound());
82     AffineExpr stepExpr = cstr.getExpr(forOp.getStep());
83     AffineExpr tripCountExpr =
84         AffineExpr(ubExpr - lbExpr).ceilDiv(stepExpr); // (ub - lb) / step
85     AffineExpr oneIterAdvanceExpr =
86         cstr.getExpr(yieldedValue) - cstr.getExpr(iterArg);
87     cstr.bound(value) ==
88         cstr.getExpr(initArg) + AffineExpr(tripCountExpr * oneIterAdvanceExpr);
89   }
90 
91   void populateBoundsForIndexValue(Operation *op, Value value,
92                                    ValueBoundsConstraintSet &cstr) const {
93     auto forOp = cast<ForOp>(op);
94 
95     if (value == forOp.getInductionVar()) {
96       // TODO: Take into account step size.
97       cstr.bound(value) >= forOp.getLowerBound();
98       cstr.bound(value) < forOp.getUpperBound();
99       return;
100     }
101 
102     // Handle iter_args and OpResults.
103     populateIterArgBounds(forOp, value, std::nullopt, cstr);
104   }
105 
106   void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
107                                        ValueBoundsConstraintSet &cstr) const {
108     auto forOp = cast<ForOp>(op);
109     // Handle iter_args and OpResults.
110     populateIterArgBounds(forOp, value, dim, cstr);
111   }
112 };
113 
114 struct ForallOpInterface
115     : public ValueBoundsOpInterface::ExternalModel<ForallOpInterface,
116                                                    ForallOp> {
117 
118   void populateBoundsForIndexValue(Operation *op, Value value,
119                                    ValueBoundsConstraintSet &cstr) const {
120     auto forallOp = cast<ForallOp>(op);
121 
122     // Index values should be induction variables, since the semantics of
123     // tensor::ParallelInsertSliceOp requires forall outputs to be ranked
124     // tensors.
125     auto blockArg = cast<BlockArgument>(value);
126     assert(blockArg.getArgNumber() < forallOp.getInductionVars().size() &&
127            "expected index value to be an induction var");
128     int64_t idx = blockArg.getArgNumber();
129     // TODO: Take into account step size.
130     AffineExpr lb = cstr.getExpr(forallOp.getMixedLowerBound()[idx]);
131     AffineExpr ub = cstr.getExpr(forallOp.getMixedUpperBound()[idx]);
132     cstr.bound(value) >= lb;
133     cstr.bound(value) < ub;
134   }
135 
136   void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
137                                        ValueBoundsConstraintSet &cstr) const {
138     auto forallOp = cast<ForallOp>(op);
139 
140     // `value` is an iter_arg or an OpResult.
141     int64_t iterArgIdx;
142     if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
143       iterArgIdx = iterArg.getArgNumber() - forallOp.getInductionVars().size();
144     } else {
145       iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
146     }
147 
148     // The forall results and output arguments have the same sizes as the output
149     // operands.
150     Value outputOperand = forallOp.getOutputs()[iterArgIdx];
151     cstr.bound(value)[dim] == cstr.getExpr(outputOperand, dim);
152   }
153 };
154 
155 struct IfOpInterface
156     : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
157 
158   static void populateBounds(scf::IfOp ifOp, Value value,
159                              std::optional<int64_t> dim,
160                              ValueBoundsConstraintSet &cstr) {
161     unsigned int resultNum = cast<OpResult>(value).getResultNumber();
162     Value thenValue = ifOp.thenYield().getResults()[resultNum];
163     Value elseValue = ifOp.elseYield().getResults()[resultNum];
164 
165     auto boundsBuilder = cstr.bound(value);
166     if (dim)
167       boundsBuilder[*dim];
168 
169     // Compare yielded values.
170     // If thenValue <= elseValue:
171     // * result <= elseValue
172     // * result >= thenValue
173     if (cstr.populateAndCompare(
174             /*lhs=*/{thenValue, dim},
175             ValueBoundsConstraintSet::ComparisonOperator::LE,
176             /*rhs=*/{elseValue, dim})) {
177       if (dim) {
178         cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
179         cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
180       } else {
181         cstr.bound(value) >= thenValue;
182         cstr.bound(value) <= elseValue;
183       }
184     }
185     // If elseValue <= thenValue:
186     // * result <= thenValue
187     // * result >= elseValue
188     if (cstr.populateAndCompare(
189             /*lhs=*/{elseValue, dim},
190             ValueBoundsConstraintSet::ComparisonOperator::LE,
191             /*rhs=*/{thenValue, dim})) {
192       if (dim) {
193         cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
194         cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
195       } else {
196         cstr.bound(value) >= elseValue;
197         cstr.bound(value) <= thenValue;
198       }
199     }
200   }
201 
202   void populateBoundsForIndexValue(Operation *op, Value value,
203                                    ValueBoundsConstraintSet &cstr) const {
204     populateBounds(cast<IfOp>(op), value, /*dim=*/std::nullopt, cstr);
205   }
206 
207   void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
208                                        ValueBoundsConstraintSet &cstr) const {
209     populateBounds(cast<IfOp>(op), value, dim, cstr);
210   }
211 };
212 
213 } // namespace
214 } // namespace scf
215 } // namespace mlir
216 
217 void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
218     DialectRegistry &registry) {
219   registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
220     scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
221     scf::ForallOp::attachInterface<scf::ForallOpInterface>(*ctx);
222     scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
223   });
224 }
225