xref: /llvm-project/mlir/include/mlir/Dialect/CommonFolders.h (revision 650635586220aa8878397579744b71effb35938e)
1 //===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file declares various common operation folders. These folders
10 // are intended to be used by dialects to support common folding behavior
11 // without requiring each dialect to provide its own implementation.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_DIALECT_COMMONFOLDERS_H
16 #define MLIR_DIALECT_COMMONFOLDERS_H
17 
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include <optional>
23 
24 namespace mlir {
25 namespace ub {
26 class PoisonAttr;
27 }
28 /// Performs constant folding `calculate` with element-wise behavior on the two
29 /// attributes in `operands` and returns the result if possible.
30 /// Uses `resultType` for the type of the returned attribute.
31 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
32 /// which will be directly propagated to result.
33 template <class AttrElementT,
34           class ElementValueT = typename AttrElementT::ValueType,
35           class PoisonAttr = ub::PoisonAttr,
36           class CalculationT = function_ref<
37               std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
constFoldBinaryOpConditional(ArrayRef<Attribute> operands,Type resultType,CalculationT && calculate)38 Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
39                                        Type resultType,
40                                        CalculationT &&calculate) {
41   assert(operands.size() == 2 && "binary op takes two operands");
42   static_assert(
43       std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
44       "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
45       "void as template argument to opt-out from poison semantics.");
46   if constexpr (!std::is_void_v<PoisonAttr>) {
47     if (isa_and_nonnull<PoisonAttr>(operands[0]))
48       return operands[0];
49 
50     if (isa_and_nonnull<PoisonAttr>(operands[1]))
51       return operands[1];
52   }
53 
54   if (!resultType || !operands[0] || !operands[1])
55     return {};
56 
57   if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) {
58     auto lhs = cast<AttrElementT>(operands[0]);
59     auto rhs = cast<AttrElementT>(operands[1]);
60     if (lhs.getType() != rhs.getType())
61       return {};
62 
63     auto calRes = calculate(lhs.getValue(), rhs.getValue());
64 
65     if (!calRes)
66       return {};
67 
68     return AttrElementT::get(resultType, *calRes);
69   }
70 
71   if (isa<SplatElementsAttr>(operands[0]) &&
72       isa<SplatElementsAttr>(operands[1])) {
73     // Both operands are splats so we can avoid expanding the values out and
74     // just fold based on the splat value.
75     auto lhs = cast<SplatElementsAttr>(operands[0]);
76     auto rhs = cast<SplatElementsAttr>(operands[1]);
77     if (lhs.getType() != rhs.getType())
78       return {};
79 
80     auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
81                                    rhs.getSplatValue<ElementValueT>());
82     if (!elementResult)
83       return {};
84 
85     return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
86   }
87 
88   if (isa<ElementsAttr>(operands[0]) && isa<ElementsAttr>(operands[1])) {
89     // Operands are ElementsAttr-derived; perform an element-wise fold by
90     // expanding the values.
91     auto lhs = cast<ElementsAttr>(operands[0]);
92     auto rhs = cast<ElementsAttr>(operands[1]);
93     if (lhs.getType() != rhs.getType())
94       return {};
95 
96     auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
97     auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
98     if (!maybeLhsIt || !maybeRhsIt)
99       return {};
100     auto lhsIt = *maybeLhsIt;
101     auto rhsIt = *maybeRhsIt;
102     SmallVector<ElementValueT, 4> elementResults;
103     elementResults.reserve(lhs.getNumElements());
104     for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) {
105       auto elementResult = calculate(*lhsIt, *rhsIt);
106       if (!elementResult)
107         return {};
108       elementResults.push_back(*elementResult);
109     }
110 
111     return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
112   }
113   return {};
114 }
115 
116 /// Performs constant folding `calculate` with element-wise behavior on the two
117 /// attributes in `operands` and returns the result if possible.
118 /// Uses the operand element type for the element type of the returned
119 /// attribute.
120 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
121 /// which will be directly propagated to result.
122 template <class AttrElementT,
123           class ElementValueT = typename AttrElementT::ValueType,
124           class PoisonAttr = ub::PoisonAttr,
125           class CalculationT = function_ref<
126               std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
constFoldBinaryOpConditional(ArrayRef<Attribute> operands,CalculationT && calculate)127 Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
128                                        CalculationT &&calculate) {
129   assert(operands.size() == 2 && "binary op takes two operands");
130   static_assert(
131       std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
132       "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
133       "void as template argument to opt-out from poison semantics.");
134   if constexpr (!std::is_void_v<PoisonAttr>) {
135     if (isa_and_nonnull<PoisonAttr>(operands[0]))
136       return operands[0];
137 
138     if (isa_and_nonnull<PoisonAttr>(operands[1]))
139       return operands[1];
140   }
141 
142   auto getResultType = [](Attribute attr) -> Type {
143     if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
144       return typed.getType();
145     return {};
146   };
147 
148   Type lhsType = getResultType(operands[0]);
149   Type rhsType = getResultType(operands[1]);
150   if (!lhsType || !rhsType)
151     return {};
152   if (lhsType != rhsType)
153     return {};
154 
155   return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
156                                       CalculationT>(
157       operands, lhsType, std::forward<CalculationT>(calculate));
158 }
159 
160 template <class AttrElementT,
161           class ElementValueT = typename AttrElementT::ValueType,
162           class PoisonAttr = void,
163           class CalculationT =
164               function_ref<ElementValueT(ElementValueT, ElementValueT)>>
constFoldBinaryOp(ArrayRef<Attribute> operands,Type resultType,CalculationT && calculate)165 Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
166                             CalculationT &&calculate) {
167   return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
168       operands, resultType,
169       [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
170         return calculate(a, b);
171       });
172 }
173 
174 template <class AttrElementT,
175           class ElementValueT = typename AttrElementT::ValueType,
176           class PoisonAttr = ub::PoisonAttr,
177           class CalculationT =
178               function_ref<ElementValueT(ElementValueT, ElementValueT)>>
constFoldBinaryOp(ArrayRef<Attribute> operands,CalculationT && calculate)179 Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
180                             CalculationT &&calculate) {
181   return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
182       operands,
183       [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
184         return calculate(a, b);
185       });
186 }
187 
188 /// Performs constant folding `calculate` with element-wise behavior on the one
189 /// attributes in `operands` and returns the result if possible.
190 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
191 /// which will be directly propagated to result.
192 template <class AttrElementT,
193           class ElementValueT = typename AttrElementT::ValueType,
194           class PoisonAttr = ub::PoisonAttr,
195           class CalculationT =
196               function_ref<std::optional<ElementValueT>(ElementValueT)>>
constFoldUnaryOpConditional(ArrayRef<Attribute> operands,CalculationT && calculate)197 Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
198                                       CalculationT &&calculate) {
199   assert(operands.size() == 1 && "unary op takes one operands");
200   if (!operands[0])
201     return {};
202 
203   static_assert(
204       std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
205       "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
206       "void as template argument to opt-out from poison semantics.");
207   if constexpr (!std::is_void_v<PoisonAttr>) {
208     if (isa<PoisonAttr>(operands[0]))
209       return operands[0];
210   }
211 
212   if (isa<AttrElementT>(operands[0])) {
213     auto op = cast<AttrElementT>(operands[0]);
214 
215     auto res = calculate(op.getValue());
216     if (!res)
217       return {};
218     return AttrElementT::get(op.getType(), *res);
219   }
220   if (isa<SplatElementsAttr>(operands[0])) {
221     // Both operands are splats so we can avoid expanding the values out and
222     // just fold based on the splat value.
223     auto op = cast<SplatElementsAttr>(operands[0]);
224 
225     auto elementResult = calculate(op.getSplatValue<ElementValueT>());
226     if (!elementResult)
227       return {};
228     return DenseElementsAttr::get(op.getType(), *elementResult);
229   } else if (isa<ElementsAttr>(operands[0])) {
230     // Operands are ElementsAttr-derived; perform an element-wise fold by
231     // expanding the values.
232     auto op = cast<ElementsAttr>(operands[0]);
233 
234     auto maybeOpIt = op.try_value_begin<ElementValueT>();
235     if (!maybeOpIt)
236       return {};
237     auto opIt = *maybeOpIt;
238     SmallVector<ElementValueT> elementResults;
239     elementResults.reserve(op.getNumElements());
240     for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
241       auto elementResult = calculate(*opIt);
242       if (!elementResult)
243         return {};
244       elementResults.push_back(*elementResult);
245     }
246     return DenseElementsAttr::get(op.getShapedType(), elementResults);
247   }
248   return {};
249 }
250 
251 template <class AttrElementT,
252           class ElementValueT = typename AttrElementT::ValueType,
253           class PoisonAttr = ub::PoisonAttr,
254           class CalculationT = function_ref<ElementValueT(ElementValueT)>>
constFoldUnaryOp(ArrayRef<Attribute> operands,CalculationT && calculate)255 Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
256                            CalculationT &&calculate) {
257   return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
258       operands, [&](ElementValueT a) -> std::optional<ElementValueT> {
259         return calculate(a);
260       });
261 }
262 
263 template <
264     class AttrElementT, class TargetAttrElementT,
265     class ElementValueT = typename AttrElementT::ValueType,
266     class TargetElementValueT = typename TargetAttrElementT::ValueType,
267     class PoisonAttr = ub::PoisonAttr,
268     class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
constFoldCastOp(ArrayRef<Attribute> operands,Type resType,CalculationT && calculate)269 Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
270                           CalculationT &&calculate) {
271   assert(operands.size() == 1 && "Cast op takes one operand");
272   if (!operands[0])
273     return {};
274 
275   static_assert(
276       std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
277       "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
278       "void as template argument to opt-out from poison semantics.");
279   if constexpr (!std::is_void_v<PoisonAttr>) {
280     if (isa<PoisonAttr>(operands[0]))
281       return operands[0];
282   }
283 
284   if (isa<AttrElementT>(operands[0])) {
285     auto op = cast<AttrElementT>(operands[0]);
286     bool castStatus = true;
287     auto res = calculate(op.getValue(), castStatus);
288     if (!castStatus)
289       return {};
290     return TargetAttrElementT::get(resType, res);
291   }
292   if (isa<SplatElementsAttr>(operands[0])) {
293     // The operand is a splat so we can avoid expanding the values out and
294     // just fold based on the splat value.
295     auto op = cast<SplatElementsAttr>(operands[0]);
296     bool castStatus = true;
297     auto elementResult =
298         calculate(op.getSplatValue<ElementValueT>(), castStatus);
299     if (!castStatus)
300       return {};
301     auto shapedResType = cast<ShapedType>(resType);
302     if (!shapedResType.hasStaticShape())
303       return {};
304     return DenseElementsAttr::get(shapedResType, elementResult);
305   }
306   if (auto op = dyn_cast<ElementsAttr>(operands[0])) {
307     // Operand is ElementsAttr-derived; perform an element-wise fold by
308     // expanding the value.
309     bool castStatus = true;
310     auto maybeOpIt = op.try_value_begin<ElementValueT>();
311     if (!maybeOpIt)
312       return {};
313     auto opIt = *maybeOpIt;
314     SmallVector<TargetElementValueT> elementResults;
315     elementResults.reserve(op.getNumElements());
316     for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
317       auto elt = calculate(*opIt, castStatus);
318       if (!castStatus)
319         return {};
320       elementResults.push_back(elt);
321     }
322 
323     return DenseElementsAttr::get(cast<ShapedType>(resType), elementResults);
324   }
325   return {};
326 }
327 } // namespace mlir
328 
329 #endif // MLIR_DIALECT_COMMONFOLDERS_H
330