xref: /llvm-project/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp (revision 39ac64c1c0fc61a476aa22c53e6977608ead03cf)
1 //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
13 
14 using namespace mlir;
15 
16 namespace mlir {
17 namespace arith {
18 namespace {
19 
20 struct AddIOpInterface
21     : public ValueBoundsOpInterface::ExternalModel<AddIOpInterface, AddIOp> {
22   void populateBoundsForIndexValue(Operation *op, Value value,
23                                    ValueBoundsConstraintSet &cstr) const {
24     auto addIOp = cast<AddIOp>(op);
25     assert(value == addIOp.getResult() && "invalid value");
26 
27     // Note: `getExpr` has a side effect: it may add a new column to the
28     // constraint system. The evaluation order of addition operands is
29     // unspecified in C++. To make sure that all compilers produce the exact
30     // same results (that can be FileCheck'd), it is important that `getExpr`
31     // is called first and assigned to temporary variables, and the addition
32     // is performed afterwards.
33     AffineExpr lhs = cstr.getExpr(addIOp.getLhs());
34     AffineExpr rhs = cstr.getExpr(addIOp.getRhs());
35     cstr.bound(value) == lhs + rhs;
36   }
37 };
38 
39 struct ConstantOpInterface
40     : public ValueBoundsOpInterface::ExternalModel<ConstantOpInterface,
41                                                    ConstantOp> {
42   void populateBoundsForIndexValue(Operation *op, Value value,
43                                    ValueBoundsConstraintSet &cstr) const {
44     auto constantOp = cast<ConstantOp>(op);
45     assert(value == constantOp.getResult() && "invalid value");
46 
47     if (auto attr = llvm::dyn_cast<IntegerAttr>(constantOp.getValue()))
48       cstr.bound(value) == attr.getInt();
49   }
50 };
51 
52 struct SubIOpInterface
53     : public ValueBoundsOpInterface::ExternalModel<SubIOpInterface, SubIOp> {
54   void populateBoundsForIndexValue(Operation *op, Value value,
55                                    ValueBoundsConstraintSet &cstr) const {
56     auto subIOp = cast<SubIOp>(op);
57     assert(value == subIOp.getResult() && "invalid value");
58 
59     AffineExpr lhs = cstr.getExpr(subIOp.getLhs());
60     AffineExpr rhs = cstr.getExpr(subIOp.getRhs());
61     cstr.bound(value) == lhs - rhs;
62   }
63 };
64 
65 struct MulIOpInterface
66     : public ValueBoundsOpInterface::ExternalModel<MulIOpInterface, MulIOp> {
67   void populateBoundsForIndexValue(Operation *op, Value value,
68                                    ValueBoundsConstraintSet &cstr) const {
69     auto mulIOp = cast<MulIOp>(op);
70     assert(value == mulIOp.getResult() && "invalid value");
71 
72     AffineExpr lhs = cstr.getExpr(mulIOp.getLhs());
73     AffineExpr rhs = cstr.getExpr(mulIOp.getRhs());
74     cstr.bound(value) == lhs *rhs;
75   }
76 };
77 
78 struct SelectOpInterface
79     : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
80                                                    SelectOp> {
81 
82   static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
83                              ValueBoundsConstraintSet &cstr) {
84     Value value = selectOp.getResult();
85     Value condition = selectOp.getCondition();
86     Value trueValue = selectOp.getTrueValue();
87     Value falseValue = selectOp.getFalseValue();
88 
89     if (isa<ShapedType>(condition.getType())) {
90       // If the condition is a shaped type, the condition is applied
91       // element-wise. All three operands must have the same shape.
92       cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim);
93       cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim);
94       cstr.bound(value)[*dim] == cstr.getExpr(condition, dim);
95       return;
96     }
97 
98     // Populate constraints for the true/false values (and all values on the
99     // backward slice, as long as the current stop condition is not satisfied).
100     cstr.populateConstraints(trueValue, dim);
101     cstr.populateConstraints(falseValue, dim);
102     auto boundsBuilder = cstr.bound(value);
103     if (dim)
104       boundsBuilder[*dim];
105 
106     // Compare yielded values.
107     // If trueValue <= falseValue:
108     // * result <= falseValue
109     // * result >= trueValue
110     if (cstr.populateAndCompare(
111             /*lhs=*/{trueValue, dim},
112             ValueBoundsConstraintSet::ComparisonOperator::LE,
113             /*rhs=*/{falseValue, dim})) {
114       if (dim) {
115         cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
116         cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
117       } else {
118         cstr.bound(value) >= trueValue;
119         cstr.bound(value) <= falseValue;
120       }
121     }
122     // If falseValue <= trueValue:
123     // * result <= trueValue
124     // * result >= falseValue
125     if (cstr.populateAndCompare(
126             /*lhs=*/{falseValue, dim},
127             ValueBoundsConstraintSet::ComparisonOperator::LE,
128             /*rhs=*/{trueValue, dim})) {
129       if (dim) {
130         cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
131         cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
132       } else {
133         cstr.bound(value) >= falseValue;
134         cstr.bound(value) <= trueValue;
135       }
136     }
137   }
138 
139   void populateBoundsForIndexValue(Operation *op, Value value,
140                                    ValueBoundsConstraintSet &cstr) const {
141     populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
142   }
143 
144   void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
145                                        ValueBoundsConstraintSet &cstr) const {
146     populateBounds(cast<SelectOp>(op), dim, cstr);
147   }
148 };
149 } // namespace
150 } // namespace arith
151 } // namespace mlir
152 
153 void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
154     DialectRegistry &registry) {
155   registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
156     arith::AddIOp::attachInterface<arith::AddIOpInterface>(*ctx);
157     arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
158     arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
159     arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
160     arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
161   });
162 }
163