xref: /llvm-project/mlir/include/mlir/Transforms/HomomorphismSimplification.h (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- HomomorphismSimplification.h -----------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
10 #define MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
11 
12 #include "mlir/IR/IRMapping.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/IR/Value.h"
15 #include "mlir/Support/LLVM.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/Support/Casting.h"
18 #include <iterator>
19 #include <optional>
20 #include <type_traits>
21 #include <utility>
22 
23 namespace mlir {
24 
25 // If `h` is an homomorphism with respect to the source algebraic structure
26 // induced by function `s` and the target algebraic structure induced by
27 // function `t`, transforms `s(h(x1), h(x2) ..., h(xn))` into
28 // `h(t(x1, x2, ..., xn))`.
29 //
30 // Functors:
31 // ---------
32 // `GetHomomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
33 // Returns the operand relevant to the homomorphism.
34 // There may be other operands that are not relevant.
35 //
36 // `GetHomomorphismOpResultFn`: `(Operation*) -> OpResult`
37 // Returns the result relevant to the homomorphism.
38 //
39 // `GetSourceAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) ->
40 // void` Populates into the vector the operands relevant to the homomorphism.
41 //
42 // `GetSourceAlgebraicOpResultFn`: `(Operation*) -> OpResult`
43 //  Return the result of the source algebraic operation relevant to the
44 //  homomorphism.
45 //
46 // `GetTargetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
47 //  Return the result of the target algebraic operation relevant to the
48 //  homomorphism.
49 //
50 // `IsHomomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
51 // Check if the operation is an homomorphism of the required type.
52 // Additionally if the optional is present checks if the operations are
53 // compatible homomorphisms.
54 //
55 // `IsSourceAlgebraicOpFn`: `(Operation*) -> bool`
56 // Check if the operation is an operation of the algebraic structure.
57 //
58 // `CreateTargetAlgebraicOpFn`: `(Operation*, IRMapping& operandsRemapping,
59 // PatternRewriter &rewriter) -> Operation*`
60 template <typename GetHomomorphismOpOperandFn,
61           typename GetHomomorphismOpResultFn,
62           typename GetSourceAlgebraicOpOperandsFn,
63           typename GetSourceAlgebraicOpResultFn,
64           typename GetTargetAlgebraicOpResultFn, typename IsHomomorphismOpFn,
65           typename IsSourceAlgebraicOpFn, typename CreateTargetAlgebraicOpFn>
66 struct HomomorphismSimplification : public RewritePattern {
67   template <typename GetHomomorphismOpOperandFnArg,
68             typename GetHomomorphismOpResultFnArg,
69             typename GetSourceAlgebraicOpOperandsFnArg,
70             typename GetSourceAlgebraicOpResultFnArg,
71             typename GetTargetAlgebraicOpResultFnArg,
72             typename IsHomomorphismOpFnArg, typename IsSourceAlgebraicOpFnArg,
73             typename CreateTargetAlgebraicOpFnArg,
74             typename... RewritePatternArgs>
HomomorphismSimplificationHomomorphismSimplification75   HomomorphismSimplification(
76       GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand,
77       GetHomomorphismOpResultFnArg &&getHomomorphismOpResult,
78       GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands,
79       GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult,
80       GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult,
81       IsHomomorphismOpFnArg &&isHomomorphismOp,
82       IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp,
83       CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn,
84       RewritePatternArgs &&...args)
85       : RewritePattern(std::forward<RewritePatternArgs>(args)...),
86         getHomomorphismOpOperand(std::forward<GetHomomorphismOpOperandFnArg>(
87             getHomomorphismOpOperand)),
88         getHomomorphismOpResult(std::forward<GetHomomorphismOpResultFnArg>(
89             getHomomorphismOpResult)),
90         getSourceAlgebraicOpOperands(
91             std::forward<GetSourceAlgebraicOpOperandsFnArg>(
92                 getSourceAlgebraicOpOperands)),
93         getSourceAlgebraicOpResult(
94             std::forward<GetSourceAlgebraicOpResultFnArg>(
95                 getSourceAlgebraicOpResult)),
96         getTargetAlgebraicOpResult(
97             std::forward<GetTargetAlgebraicOpResultFnArg>(
98                 getTargetAlgebraicOpResult)),
99         isHomomorphismOp(std::forward<IsHomomorphismOpFnArg>(isHomomorphismOp)),
100         isSourceAlgebraicOp(
101             std::forward<IsSourceAlgebraicOpFnArg>(isSourceAlgebraicOp)),
102         createTargetAlgebraicOpFn(std::forward<CreateTargetAlgebraicOpFnArg>(
103             createTargetAlgebraicOpFn)) {}
104 
matchAndRewriteHomomorphismSimplification105   LogicalResult matchAndRewrite(Operation *op,
106                                 PatternRewriter &rewriter) const override {
107     SmallVector<OpOperand *> algebraicOpOperands;
108     if (failed(matchOp(op, algebraicOpOperands))) {
109       return failure();
110     }
111     return rewriteOp(op, algebraicOpOperands, rewriter);
112   }
113 
114 private:
115   LogicalResult
matchOpHomomorphismSimplification116   matchOp(Operation *sourceAlgebraicOp,
117           SmallVector<OpOperand *> &sourceAlgebraicOpOperands) const {
118     if (!isSourceAlgebraicOp(sourceAlgebraicOp)) {
119       return failure();
120     }
121     sourceAlgebraicOpOperands.clear();
122     getSourceAlgebraicOpOperands(sourceAlgebraicOp, sourceAlgebraicOpOperands);
123     if (sourceAlgebraicOpOperands.empty()) {
124       return failure();
125     }
126 
127     Operation *firstHomomorphismOp =
128         sourceAlgebraicOpOperands.front()->get().getDefiningOp();
129     if (!firstHomomorphismOp ||
130         !isHomomorphismOp(firstHomomorphismOp, std::nullopt)) {
131       return failure();
132     }
133     OpResult firstHomomorphismOpResult =
134         getHomomorphismOpResult(firstHomomorphismOp);
135     if (firstHomomorphismOpResult != sourceAlgebraicOpOperands.front()->get()) {
136       return failure();
137     }
138 
139     for (auto operand : sourceAlgebraicOpOperands) {
140       Operation *homomorphismOp = operand->get().getDefiningOp();
141       if (!homomorphismOp ||
142           !isHomomorphismOp(homomorphismOp, firstHomomorphismOp)) {
143         return failure();
144       }
145     }
146     return success();
147   }
148 
149   LogicalResult
rewriteOpHomomorphismSimplification150   rewriteOp(Operation *sourceAlgebraicOp,
151             const SmallVector<OpOperand *> &sourceAlgebraicOpOperands,
152             PatternRewriter &rewriter) const {
153     IRMapping irMapping;
154     for (auto operand : sourceAlgebraicOpOperands) {
155       Operation *homomorphismOp = operand->get().getDefiningOp();
156       irMapping.map(operand->get(),
157                     getHomomorphismOpOperand(homomorphismOp)->get());
158     }
159     Operation *targetAlgebraicOp =
160         createTargetAlgebraicOpFn(sourceAlgebraicOp, irMapping, rewriter);
161 
162     irMapping.clear();
163     assert(!sourceAlgebraicOpOperands.empty());
164     Operation *firstHomomorphismOp =
165         sourceAlgebraicOpOperands[0]->get().getDefiningOp();
166     irMapping.map(getHomomorphismOpOperand(firstHomomorphismOp)->get(),
167                   getTargetAlgebraicOpResult(targetAlgebraicOp));
168     Operation *newHomomorphismOp =
169         rewriter.clone(*firstHomomorphismOp, irMapping);
170     rewriter.replaceAllUsesWith(getSourceAlgebraicOpResult(sourceAlgebraicOp),
171                                 getHomomorphismOpResult(newHomomorphismOp));
172     return success();
173   }
174 
175   GetHomomorphismOpOperandFn getHomomorphismOpOperand;
176   GetHomomorphismOpResultFn getHomomorphismOpResult;
177   GetSourceAlgebraicOpOperandsFn getSourceAlgebraicOpOperands;
178   GetSourceAlgebraicOpResultFn getSourceAlgebraicOpResult;
179   GetTargetAlgebraicOpResultFn getTargetAlgebraicOpResult;
180   IsHomomorphismOpFn isHomomorphismOp;
181   IsSourceAlgebraicOpFn isSourceAlgebraicOp;
182   CreateTargetAlgebraicOpFn createTargetAlgebraicOpFn;
183 };
184 
185 } // namespace mlir
186 
187 #endif // MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
188