xref: /llvm-project/mlir/include/mlir/Transforms/EndomorphismSimplification.h (revision 4b3446771f745bb5169354ad9027c0a1c9fca394)
1 //===- EndomorphismSimplification.h -----------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
10 #define MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
11 
12 #include "mlir/Transforms/HomomorphismSimplification.h"
13 
14 namespace mlir {
15 
16 namespace detail {
17 struct CreateAlgebraicOpForEndomorphismSimplification {
operatorCreateAlgebraicOpForEndomorphismSimplification18   Operation *operator()(Operation *op, IRMapping &operandsRemapping,
19                         PatternRewriter &rewriter) const {
20     return rewriter.clone(*op, operandsRemapping);
21   }
22 };
23 } // namespace detail
24 
25 // If `f` is an endomorphism with respect to the algebraic structure induced by
26 // function `g`, transforms `g(f(x1), f(x2) ..., f(xn))` into
27 // `f(g(x1, x2, ..., xn))`.
28 // `g` is the algebraic operation and `f` is the endomorphism.
29 //
30 // Functors:
31 // ---------
32 // `GetEndomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
33 // Returns the operand relevant to the endomorphism.
34 // There may be other operands that are not relevant.
35 //
36 // `GetEndomorphismOpResultFn`: `(Operation*) -> OpResult`
37 // Returns the result relevant to the endomorphism.
38 //
39 // `GetAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) -> void`
40 // Populates into the vector the operands relevant to the endomorphism.
41 //
42 // `GetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
43 //  Return the result relevant to the endomorphism.
44 //
45 // `IsEndomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
46 // Check if the operation is an endomorphism of the required type.
47 // Additionally if the optional is present checks if the operations are
48 // compatible endomorphisms.
49 //
50 // `IsAlgebraicOpFn`: `(Operation*) -> bool`
51 // Check if the operation is an operation of the algebraic structure.
52 template <typename GetEndomorphismOpOperandFn,
53           typename GetEndomorphismOpResultFn, typename GetAlgebraicOpOperandsFn,
54           typename GetAlgebraicOpResultFn, typename IsEndomorphismOpFn,
55           typename IsAlgebraicOpFn>
56 struct EndomorphismSimplification
57     : HomomorphismSimplification<
58           GetEndomorphismOpOperandFn, GetEndomorphismOpResultFn,
59           GetAlgebraicOpOperandsFn, GetAlgebraicOpResultFn,
60           GetAlgebraicOpResultFn, IsEndomorphismOpFn, IsAlgebraicOpFn,
61           detail::CreateAlgebraicOpForEndomorphismSimplification> {
62   template <typename GetEndomorphismOpOperandFnArg,
63             typename GetEndomorphismOpResultFnArg,
64             typename GetAlgebraicOpOperandsFnArg,
65             typename GetAlgebraicOpResultFnArg, typename IsEndomorphismOpFnArg,
66             typename IsAlgebraicOpFnArg, typename... RewritePatternArgs>
EndomorphismSimplificationEndomorphismSimplification67   EndomorphismSimplification(
68       GetEndomorphismOpOperandFnArg &&getEndomorphismOpOperand,
69       GetEndomorphismOpResultFnArg &&getEndomorphismOpResult,
70       GetAlgebraicOpOperandsFnArg &&getAlgebraicOpOperands,
71       GetAlgebraicOpResultFnArg &&getAlgebraicOpResult,
72       IsEndomorphismOpFnArg &&isEndomorphismOp,
73       IsAlgebraicOpFnArg &&isAlgebraicOp, RewritePatternArgs &&...args)
74       : HomomorphismSimplification<
75             GetEndomorphismOpOperandFn, GetEndomorphismOpResultFn,
76             GetAlgebraicOpOperandsFn, GetAlgebraicOpResultFn,
77             GetAlgebraicOpResultFn, IsEndomorphismOpFn, IsAlgebraicOpFn,
78             detail::CreateAlgebraicOpForEndomorphismSimplification>(
79             std::forward<GetEndomorphismOpOperandFnArg>(
80                 getEndomorphismOpOperand),
81             std::forward<GetEndomorphismOpResultFnArg>(getEndomorphismOpResult),
82             std::forward<GetAlgebraicOpOperandsFnArg>(getAlgebraicOpOperands),
83             std::forward<GetAlgebraicOpResultFnArg>(getAlgebraicOpResult),
84             std::forward<GetAlgebraicOpResultFnArg>(getAlgebraicOpResult),
85             std::forward<IsEndomorphismOpFnArg>(isEndomorphismOp),
86             std::forward<IsAlgebraicOpFnArg>(isAlgebraicOp),
87             detail::CreateAlgebraicOpForEndomorphismSimplification(),
88             std::forward<RewritePatternArgs>(args)...) {}
89 };
90 
91 } // namespace mlir
92 
93 #endif // MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
94