1 //===- EndomorphismSimplification.h -----------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_ 10 #define MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_ 11 12 #include "mlir/Transforms/HomomorphismSimplification.h" 13 14 namespace mlir { 15 16 namespace detail { 17 struct CreateAlgebraicOpForEndomorphismSimplification { operatorCreateAlgebraicOpForEndomorphismSimplification18 Operation *operator()(Operation *op, IRMapping &operandsRemapping, 19 PatternRewriter &rewriter) const { 20 return rewriter.clone(*op, operandsRemapping); 21 } 22 }; 23 } // namespace detail 24 25 // If `f` is an endomorphism with respect to the algebraic structure induced by 26 // function `g`, transforms `g(f(x1), f(x2) ..., f(xn))` into 27 // `f(g(x1, x2, ..., xn))`. 28 // `g` is the algebraic operation and `f` is the endomorphism. 29 // 30 // Functors: 31 // --------- 32 // `GetEndomorphismOpOperandFn`: `(Operation*) -> OpOperand*` 33 // Returns the operand relevant to the endomorphism. 34 // There may be other operands that are not relevant. 35 // 36 // `GetEndomorphismOpResultFn`: `(Operation*) -> OpResult` 37 // Returns the result relevant to the endomorphism. 38 // 39 // `GetAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) -> void` 40 // Populates into the vector the operands relevant to the endomorphism. 41 // 42 // `GetAlgebraicOpResultFn`: `(Operation*) -> OpResult` 43 // Return the result relevant to the endomorphism. 44 // 45 // `IsEndomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool` 46 // Check if the operation is an endomorphism of the required type. 47 // Additionally if the optional is present checks if the operations are 48 // compatible endomorphisms. 49 // 50 // `IsAlgebraicOpFn`: `(Operation*) -> bool` 51 // Check if the operation is an operation of the algebraic structure. 52 template <typename GetEndomorphismOpOperandFn, 53 typename GetEndomorphismOpResultFn, typename GetAlgebraicOpOperandsFn, 54 typename GetAlgebraicOpResultFn, typename IsEndomorphismOpFn, 55 typename IsAlgebraicOpFn> 56 struct EndomorphismSimplification 57 : HomomorphismSimplification< 58 GetEndomorphismOpOperandFn, GetEndomorphismOpResultFn, 59 GetAlgebraicOpOperandsFn, GetAlgebraicOpResultFn, 60 GetAlgebraicOpResultFn, IsEndomorphismOpFn, IsAlgebraicOpFn, 61 detail::CreateAlgebraicOpForEndomorphismSimplification> { 62 template <typename GetEndomorphismOpOperandFnArg, 63 typename GetEndomorphismOpResultFnArg, 64 typename GetAlgebraicOpOperandsFnArg, 65 typename GetAlgebraicOpResultFnArg, typename IsEndomorphismOpFnArg, 66 typename IsAlgebraicOpFnArg, typename... RewritePatternArgs> EndomorphismSimplificationEndomorphismSimplification67 EndomorphismSimplification( 68 GetEndomorphismOpOperandFnArg &&getEndomorphismOpOperand, 69 GetEndomorphismOpResultFnArg &&getEndomorphismOpResult, 70 GetAlgebraicOpOperandsFnArg &&getAlgebraicOpOperands, 71 GetAlgebraicOpResultFnArg &&getAlgebraicOpResult, 72 IsEndomorphismOpFnArg &&isEndomorphismOp, 73 IsAlgebraicOpFnArg &&isAlgebraicOp, RewritePatternArgs &&...args) 74 : HomomorphismSimplification< 75 GetEndomorphismOpOperandFn, GetEndomorphismOpResultFn, 76 GetAlgebraicOpOperandsFn, GetAlgebraicOpResultFn, 77 GetAlgebraicOpResultFn, IsEndomorphismOpFn, IsAlgebraicOpFn, 78 detail::CreateAlgebraicOpForEndomorphismSimplification>( 79 std::forward<GetEndomorphismOpOperandFnArg>( 80 getEndomorphismOpOperand), 81 std::forward<GetEndomorphismOpResultFnArg>(getEndomorphismOpResult), 82 std::forward<GetAlgebraicOpOperandsFnArg>(getAlgebraicOpOperands), 83 std::forward<GetAlgebraicOpResultFnArg>(getAlgebraicOpResult), 84 std::forward<GetAlgebraicOpResultFnArg>(getAlgebraicOpResult), 85 std::forward<IsEndomorphismOpFnArg>(isEndomorphismOp), 86 std::forward<IsAlgebraicOpFnArg>(isAlgebraicOp), 87 detail::CreateAlgebraicOpForEndomorphismSimplification(), 88 std::forward<RewritePatternArgs>(args)...) {} 89 }; 90 91 } // namespace mlir 92 93 #endif // MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_ 94