1 //===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file declares various common operation folders. These folders
10 // are intended to be used by dialects to support common folding behavior
11 // without requiring each dialect to provide its own implementation.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #ifndef MLIR_DIALECT_COMMONFOLDERS_H
16 #define MLIR_DIALECT_COMMONFOLDERS_H
17
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include <optional>
23
24 namespace mlir {
25 namespace ub {
26 class PoisonAttr;
27 }
28 /// Performs constant folding `calculate` with element-wise behavior on the two
29 /// attributes in `operands` and returns the result if possible.
30 /// Uses `resultType` for the type of the returned attribute.
31 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
32 /// which will be directly propagated to result.
33 template <class AttrElementT,
34 class ElementValueT = typename AttrElementT::ValueType,
35 class PoisonAttr = ub::PoisonAttr,
36 class CalculationT = function_ref<
37 std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
constFoldBinaryOpConditional(ArrayRef<Attribute> operands,Type resultType,CalculationT && calculate)38 Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
39 Type resultType,
40 CalculationT &&calculate) {
41 assert(operands.size() == 2 && "binary op takes two operands");
42 static_assert(
43 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
44 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
45 "void as template argument to opt-out from poison semantics.");
46 if constexpr (!std::is_void_v<PoisonAttr>) {
47 if (isa_and_nonnull<PoisonAttr>(operands[0]))
48 return operands[0];
49
50 if (isa_and_nonnull<PoisonAttr>(operands[1]))
51 return operands[1];
52 }
53
54 if (!resultType || !operands[0] || !operands[1])
55 return {};
56
57 if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) {
58 auto lhs = cast<AttrElementT>(operands[0]);
59 auto rhs = cast<AttrElementT>(operands[1]);
60 if (lhs.getType() != rhs.getType())
61 return {};
62
63 auto calRes = calculate(lhs.getValue(), rhs.getValue());
64
65 if (!calRes)
66 return {};
67
68 return AttrElementT::get(resultType, *calRes);
69 }
70
71 if (isa<SplatElementsAttr>(operands[0]) &&
72 isa<SplatElementsAttr>(operands[1])) {
73 // Both operands are splats so we can avoid expanding the values out and
74 // just fold based on the splat value.
75 auto lhs = cast<SplatElementsAttr>(operands[0]);
76 auto rhs = cast<SplatElementsAttr>(operands[1]);
77 if (lhs.getType() != rhs.getType())
78 return {};
79
80 auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
81 rhs.getSplatValue<ElementValueT>());
82 if (!elementResult)
83 return {};
84
85 return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
86 }
87
88 if (isa<ElementsAttr>(operands[0]) && isa<ElementsAttr>(operands[1])) {
89 // Operands are ElementsAttr-derived; perform an element-wise fold by
90 // expanding the values.
91 auto lhs = cast<ElementsAttr>(operands[0]);
92 auto rhs = cast<ElementsAttr>(operands[1]);
93 if (lhs.getType() != rhs.getType())
94 return {};
95
96 auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
97 auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
98 if (!maybeLhsIt || !maybeRhsIt)
99 return {};
100 auto lhsIt = *maybeLhsIt;
101 auto rhsIt = *maybeRhsIt;
102 SmallVector<ElementValueT, 4> elementResults;
103 elementResults.reserve(lhs.getNumElements());
104 for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) {
105 auto elementResult = calculate(*lhsIt, *rhsIt);
106 if (!elementResult)
107 return {};
108 elementResults.push_back(*elementResult);
109 }
110
111 return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
112 }
113 return {};
114 }
115
116 /// Performs constant folding `calculate` with element-wise behavior on the two
117 /// attributes in `operands` and returns the result if possible.
118 /// Uses the operand element type for the element type of the returned
119 /// attribute.
120 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
121 /// which will be directly propagated to result.
122 template <class AttrElementT,
123 class ElementValueT = typename AttrElementT::ValueType,
124 class PoisonAttr = ub::PoisonAttr,
125 class CalculationT = function_ref<
126 std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
constFoldBinaryOpConditional(ArrayRef<Attribute> operands,CalculationT && calculate)127 Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
128 CalculationT &&calculate) {
129 assert(operands.size() == 2 && "binary op takes two operands");
130 static_assert(
131 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
132 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
133 "void as template argument to opt-out from poison semantics.");
134 if constexpr (!std::is_void_v<PoisonAttr>) {
135 if (isa_and_nonnull<PoisonAttr>(operands[0]))
136 return operands[0];
137
138 if (isa_and_nonnull<PoisonAttr>(operands[1]))
139 return operands[1];
140 }
141
142 auto getResultType = [](Attribute attr) -> Type {
143 if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
144 return typed.getType();
145 return {};
146 };
147
148 Type lhsType = getResultType(operands[0]);
149 Type rhsType = getResultType(operands[1]);
150 if (!lhsType || !rhsType)
151 return {};
152 if (lhsType != rhsType)
153 return {};
154
155 return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
156 CalculationT>(
157 operands, lhsType, std::forward<CalculationT>(calculate));
158 }
159
160 template <class AttrElementT,
161 class ElementValueT = typename AttrElementT::ValueType,
162 class PoisonAttr = void,
163 class CalculationT =
164 function_ref<ElementValueT(ElementValueT, ElementValueT)>>
constFoldBinaryOp(ArrayRef<Attribute> operands,Type resultType,CalculationT && calculate)165 Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
166 CalculationT &&calculate) {
167 return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
168 operands, resultType,
169 [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
170 return calculate(a, b);
171 });
172 }
173
174 template <class AttrElementT,
175 class ElementValueT = typename AttrElementT::ValueType,
176 class PoisonAttr = ub::PoisonAttr,
177 class CalculationT =
178 function_ref<ElementValueT(ElementValueT, ElementValueT)>>
constFoldBinaryOp(ArrayRef<Attribute> operands,CalculationT && calculate)179 Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
180 CalculationT &&calculate) {
181 return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
182 operands,
183 [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
184 return calculate(a, b);
185 });
186 }
187
188 /// Performs constant folding `calculate` with element-wise behavior on the one
189 /// attributes in `operands` and returns the result if possible.
190 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
191 /// which will be directly propagated to result.
192 template <class AttrElementT,
193 class ElementValueT = typename AttrElementT::ValueType,
194 class PoisonAttr = ub::PoisonAttr,
195 class CalculationT =
196 function_ref<std::optional<ElementValueT>(ElementValueT)>>
constFoldUnaryOpConditional(ArrayRef<Attribute> operands,CalculationT && calculate)197 Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
198 CalculationT &&calculate) {
199 assert(operands.size() == 1 && "unary op takes one operands");
200 if (!operands[0])
201 return {};
202
203 static_assert(
204 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
205 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
206 "void as template argument to opt-out from poison semantics.");
207 if constexpr (!std::is_void_v<PoisonAttr>) {
208 if (isa<PoisonAttr>(operands[0]))
209 return operands[0];
210 }
211
212 if (isa<AttrElementT>(operands[0])) {
213 auto op = cast<AttrElementT>(operands[0]);
214
215 auto res = calculate(op.getValue());
216 if (!res)
217 return {};
218 return AttrElementT::get(op.getType(), *res);
219 }
220 if (isa<SplatElementsAttr>(operands[0])) {
221 // Both operands are splats so we can avoid expanding the values out and
222 // just fold based on the splat value.
223 auto op = cast<SplatElementsAttr>(operands[0]);
224
225 auto elementResult = calculate(op.getSplatValue<ElementValueT>());
226 if (!elementResult)
227 return {};
228 return DenseElementsAttr::get(op.getType(), *elementResult);
229 } else if (isa<ElementsAttr>(operands[0])) {
230 // Operands are ElementsAttr-derived; perform an element-wise fold by
231 // expanding the values.
232 auto op = cast<ElementsAttr>(operands[0]);
233
234 auto maybeOpIt = op.try_value_begin<ElementValueT>();
235 if (!maybeOpIt)
236 return {};
237 auto opIt = *maybeOpIt;
238 SmallVector<ElementValueT> elementResults;
239 elementResults.reserve(op.getNumElements());
240 for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
241 auto elementResult = calculate(*opIt);
242 if (!elementResult)
243 return {};
244 elementResults.push_back(*elementResult);
245 }
246 return DenseElementsAttr::get(op.getShapedType(), elementResults);
247 }
248 return {};
249 }
250
251 template <class AttrElementT,
252 class ElementValueT = typename AttrElementT::ValueType,
253 class PoisonAttr = ub::PoisonAttr,
254 class CalculationT = function_ref<ElementValueT(ElementValueT)>>
constFoldUnaryOp(ArrayRef<Attribute> operands,CalculationT && calculate)255 Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
256 CalculationT &&calculate) {
257 return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
258 operands, [&](ElementValueT a) -> std::optional<ElementValueT> {
259 return calculate(a);
260 });
261 }
262
263 template <
264 class AttrElementT, class TargetAttrElementT,
265 class ElementValueT = typename AttrElementT::ValueType,
266 class TargetElementValueT = typename TargetAttrElementT::ValueType,
267 class PoisonAttr = ub::PoisonAttr,
268 class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
constFoldCastOp(ArrayRef<Attribute> operands,Type resType,CalculationT && calculate)269 Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
270 CalculationT &&calculate) {
271 assert(operands.size() == 1 && "Cast op takes one operand");
272 if (!operands[0])
273 return {};
274
275 static_assert(
276 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
277 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
278 "void as template argument to opt-out from poison semantics.");
279 if constexpr (!std::is_void_v<PoisonAttr>) {
280 if (isa<PoisonAttr>(operands[0]))
281 return operands[0];
282 }
283
284 if (isa<AttrElementT>(operands[0])) {
285 auto op = cast<AttrElementT>(operands[0]);
286 bool castStatus = true;
287 auto res = calculate(op.getValue(), castStatus);
288 if (!castStatus)
289 return {};
290 return TargetAttrElementT::get(resType, res);
291 }
292 if (isa<SplatElementsAttr>(operands[0])) {
293 // The operand is a splat so we can avoid expanding the values out and
294 // just fold based on the splat value.
295 auto op = cast<SplatElementsAttr>(operands[0]);
296 bool castStatus = true;
297 auto elementResult =
298 calculate(op.getSplatValue<ElementValueT>(), castStatus);
299 if (!castStatus)
300 return {};
301 auto shapedResType = cast<ShapedType>(resType);
302 if (!shapedResType.hasStaticShape())
303 return {};
304 return DenseElementsAttr::get(shapedResType, elementResult);
305 }
306 if (auto op = dyn_cast<ElementsAttr>(operands[0])) {
307 // Operand is ElementsAttr-derived; perform an element-wise fold by
308 // expanding the value.
309 bool castStatus = true;
310 auto maybeOpIt = op.try_value_begin<ElementValueT>();
311 if (!maybeOpIt)
312 return {};
313 auto opIt = *maybeOpIt;
314 SmallVector<TargetElementValueT> elementResults;
315 elementResults.reserve(op.getNumElements());
316 for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
317 auto elt = calculate(*opIt, castStatus);
318 if (!castStatus)
319 return {};
320 elementResults.push_back(elt);
321 }
322
323 return DenseElementsAttr::get(cast<ShapedType>(resType), elementResults);
324 }
325 return {};
326 }
327 } // namespace mlir
328
329 #endif // MLIR_DIALECT_COMMONFOLDERS_H
330