1 //===- HomomorphismSimplification.h -----------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_ 10 #define MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_ 11 12 #include "mlir/IR/IRMapping.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/IR/Value.h" 15 #include "mlir/Support/LLVM.h" 16 #include "llvm/ADT/SmallVector.h" 17 #include "llvm/Support/Casting.h" 18 #include <iterator> 19 #include <optional> 20 #include <type_traits> 21 #include <utility> 22 23 namespace mlir { 24 25 // If `h` is an homomorphism with respect to the source algebraic structure 26 // induced by function `s` and the target algebraic structure induced by 27 // function `t`, transforms `s(h(x1), h(x2) ..., h(xn))` into 28 // `h(t(x1, x2, ..., xn))`. 29 // 30 // Functors: 31 // --------- 32 // `GetHomomorphismOpOperandFn`: `(Operation*) -> OpOperand*` 33 // Returns the operand relevant to the homomorphism. 34 // There may be other operands that are not relevant. 35 // 36 // `GetHomomorphismOpResultFn`: `(Operation*) -> OpResult` 37 // Returns the result relevant to the homomorphism. 38 // 39 // `GetSourceAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) -> 40 // void` Populates into the vector the operands relevant to the homomorphism. 41 // 42 // `GetSourceAlgebraicOpResultFn`: `(Operation*) -> OpResult` 43 // Return the result of the source algebraic operation relevant to the 44 // homomorphism. 45 // 46 // `GetTargetAlgebraicOpResultFn`: `(Operation*) -> OpResult` 47 // Return the result of the target algebraic operation relevant to the 48 // homomorphism. 49 // 50 // `IsHomomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool` 51 // Check if the operation is an homomorphism of the required type. 52 // Additionally if the optional is present checks if the operations are 53 // compatible homomorphisms. 54 // 55 // `IsSourceAlgebraicOpFn`: `(Operation*) -> bool` 56 // Check if the operation is an operation of the algebraic structure. 57 // 58 // `CreateTargetAlgebraicOpFn`: `(Operation*, IRMapping& operandsRemapping, 59 // PatternRewriter &rewriter) -> Operation*` 60 template <typename GetHomomorphismOpOperandFn, 61 typename GetHomomorphismOpResultFn, 62 typename GetSourceAlgebraicOpOperandsFn, 63 typename GetSourceAlgebraicOpResultFn, 64 typename GetTargetAlgebraicOpResultFn, typename IsHomomorphismOpFn, 65 typename IsSourceAlgebraicOpFn, typename CreateTargetAlgebraicOpFn> 66 struct HomomorphismSimplification : public RewritePattern { 67 template <typename GetHomomorphismOpOperandFnArg, 68 typename GetHomomorphismOpResultFnArg, 69 typename GetSourceAlgebraicOpOperandsFnArg, 70 typename GetSourceAlgebraicOpResultFnArg, 71 typename GetTargetAlgebraicOpResultFnArg, 72 typename IsHomomorphismOpFnArg, typename IsSourceAlgebraicOpFnArg, 73 typename CreateTargetAlgebraicOpFnArg, 74 typename... RewritePatternArgs> HomomorphismSimplificationHomomorphismSimplification75 HomomorphismSimplification( 76 GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand, 77 GetHomomorphismOpResultFnArg &&getHomomorphismOpResult, 78 GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands, 79 GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult, 80 GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult, 81 IsHomomorphismOpFnArg &&isHomomorphismOp, 82 IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp, 83 CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn, 84 RewritePatternArgs &&...args) 85 : RewritePattern(std::forward<RewritePatternArgs>(args)...), 86 getHomomorphismOpOperand(std::forward<GetHomomorphismOpOperandFnArg>( 87 getHomomorphismOpOperand)), 88 getHomomorphismOpResult(std::forward<GetHomomorphismOpResultFnArg>( 89 getHomomorphismOpResult)), 90 getSourceAlgebraicOpOperands( 91 std::forward<GetSourceAlgebraicOpOperandsFnArg>( 92 getSourceAlgebraicOpOperands)), 93 getSourceAlgebraicOpResult( 94 std::forward<GetSourceAlgebraicOpResultFnArg>( 95 getSourceAlgebraicOpResult)), 96 getTargetAlgebraicOpResult( 97 std::forward<GetTargetAlgebraicOpResultFnArg>( 98 getTargetAlgebraicOpResult)), 99 isHomomorphismOp(std::forward<IsHomomorphismOpFnArg>(isHomomorphismOp)), 100 isSourceAlgebraicOp( 101 std::forward<IsSourceAlgebraicOpFnArg>(isSourceAlgebraicOp)), 102 createTargetAlgebraicOpFn(std::forward<CreateTargetAlgebraicOpFnArg>( 103 createTargetAlgebraicOpFn)) {} 104 matchAndRewriteHomomorphismSimplification105 LogicalResult matchAndRewrite(Operation *op, 106 PatternRewriter &rewriter) const override { 107 SmallVector<OpOperand *> algebraicOpOperands; 108 if (failed(matchOp(op, algebraicOpOperands))) { 109 return failure(); 110 } 111 return rewriteOp(op, algebraicOpOperands, rewriter); 112 } 113 114 private: 115 LogicalResult matchOpHomomorphismSimplification116 matchOp(Operation *sourceAlgebraicOp, 117 SmallVector<OpOperand *> &sourceAlgebraicOpOperands) const { 118 if (!isSourceAlgebraicOp(sourceAlgebraicOp)) { 119 return failure(); 120 } 121 sourceAlgebraicOpOperands.clear(); 122 getSourceAlgebraicOpOperands(sourceAlgebraicOp, sourceAlgebraicOpOperands); 123 if (sourceAlgebraicOpOperands.empty()) { 124 return failure(); 125 } 126 127 Operation *firstHomomorphismOp = 128 sourceAlgebraicOpOperands.front()->get().getDefiningOp(); 129 if (!firstHomomorphismOp || 130 !isHomomorphismOp(firstHomomorphismOp, std::nullopt)) { 131 return failure(); 132 } 133 OpResult firstHomomorphismOpResult = 134 getHomomorphismOpResult(firstHomomorphismOp); 135 if (firstHomomorphismOpResult != sourceAlgebraicOpOperands.front()->get()) { 136 return failure(); 137 } 138 139 for (auto operand : sourceAlgebraicOpOperands) { 140 Operation *homomorphismOp = operand->get().getDefiningOp(); 141 if (!homomorphismOp || 142 !isHomomorphismOp(homomorphismOp, firstHomomorphismOp)) { 143 return failure(); 144 } 145 } 146 return success(); 147 } 148 149 LogicalResult rewriteOpHomomorphismSimplification150 rewriteOp(Operation *sourceAlgebraicOp, 151 const SmallVector<OpOperand *> &sourceAlgebraicOpOperands, 152 PatternRewriter &rewriter) const { 153 IRMapping irMapping; 154 for (auto operand : sourceAlgebraicOpOperands) { 155 Operation *homomorphismOp = operand->get().getDefiningOp(); 156 irMapping.map(operand->get(), 157 getHomomorphismOpOperand(homomorphismOp)->get()); 158 } 159 Operation *targetAlgebraicOp = 160 createTargetAlgebraicOpFn(sourceAlgebraicOp, irMapping, rewriter); 161 162 irMapping.clear(); 163 assert(!sourceAlgebraicOpOperands.empty()); 164 Operation *firstHomomorphismOp = 165 sourceAlgebraicOpOperands[0]->get().getDefiningOp(); 166 irMapping.map(getHomomorphismOpOperand(firstHomomorphismOp)->get(), 167 getTargetAlgebraicOpResult(targetAlgebraicOp)); 168 Operation *newHomomorphismOp = 169 rewriter.clone(*firstHomomorphismOp, irMapping); 170 rewriter.replaceAllUsesWith(getSourceAlgebraicOpResult(sourceAlgebraicOp), 171 getHomomorphismOpResult(newHomomorphismOp)); 172 return success(); 173 } 174 175 GetHomomorphismOpOperandFn getHomomorphismOpOperand; 176 GetHomomorphismOpResultFn getHomomorphismOpResult; 177 GetSourceAlgebraicOpOperandsFn getSourceAlgebraicOpOperands; 178 GetSourceAlgebraicOpResultFn getSourceAlgebraicOpResult; 179 GetTargetAlgebraicOpResultFn getTargetAlgebraicOpResult; 180 IsHomomorphismOpFn isHomomorphismOp; 181 IsSourceAlgebraicOpFn isSourceAlgebraicOp; 182 CreateTargetAlgebraicOpFn createTargetAlgebraicOpFn; 183 }; 184 185 } // namespace mlir 186 187 #endif // MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_ 188