1 //===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
10 #define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
11
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/EndomorphismSimplification.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <optional>
#include <type_traits>
#include <utility>
20
21 namespace mlir {
22
23 class SymbolTableCollection;
24
25 namespace mesh {
26
27 // If we have an algebraic op like "+" and a summing all-reduce,
28 // `all_reduce_sum(x) + all_reduce_sum(y)` will be transformed to
29 // `all_reduce_sum(x + y)`.
30 //
31 // Another example with `min`.
32 // `min(all_reduce_min(x), all_reduce_min(y))` will be transformed to
33 // `all_reduce_min(min(x, y))`.
34 //
35 // Works only with algebraic ops that have all their operands relevant
36 // to the all-reduce endomorphism.
37 // Will not work with some op `f(x, y, z)` where only `x` and `y` form
38 // the algebraic structure.
39 template <typename AlgebraicOp>
populateAllReduceEndomorphismSimplificationPatterns(RewritePatternSet & patterns,ReductionKind reduction)40 void populateAllReduceEndomorphismSimplificationPatterns(
41 RewritePatternSet &patterns, ReductionKind reduction) {
42 auto getEndomorphismOpOperand = [](Operation *op) {
43 auto allReduceOp = llvm::cast<AllReduceOp>(op);
44 return &allReduceOp.getInputMutable();
45 };
46 auto getEndomorphismOpResult = [](Operation *op) {
47 auto allReduceOp = llvm::cast<AllReduceOp>(op);
48 return allReduceOp->getResult(0);
49 };
50 auto getAlgebraicOpOperands = [](Operation *op,
51 SmallVector<OpOperand *> &operands) {
52 auto algebraicOp = llvm::cast<AlgebraicOp>(op);
53 std::transform(algebraicOp->getOpOperands().begin(),
54 algebraicOp->getOpOperands().end(),
55 std::back_inserter(operands),
56 [](OpOperand &operand) { return &operand; });
57 };
58 auto getAlgebraicOpResult = [](Operation *op) {
59 auto algebraicOp = llvm::cast<AlgebraicOp>(op);
60 return algebraicOp->getResult(0);
61 };
62 auto isEndomorphismOp = [reduction](Operation *op,
63 std::optional<Operation *> referenceOp) {
64 auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
65 if (!allReduceOp ||
66 allReduceOp.getInput().getType().getElementType() !=
67 allReduceOp.getResult().getType().getElementType() ||
68 allReduceOp.getReduction() != reduction) {
69 return false;
70 }
71
72 // Dont't use simplify if the all-reduce is used other than by the
73 // algebraic op.
74 // TODO: maybe handle this by an additional pass that later reverses the
75 // simplification if there are other uses left other optimizations have
76 // been done.
77 if (!allReduceOp->hasOneUse()) {
78 return false;
79 }
80
81 if (!referenceOp) {
82 return true;
83 }
84
85 auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
86 return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
87 allReduceOp.getInput().getType().getElementType() ==
88 refAllReduceOp.getInput().getType().getElementType();
89 };
90 auto isAlgebraicOp = [](Operation *op) {
91 return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
92 };
93
94 using ConcreteEndomorphismSimplification = EndomorphismSimplification<
95 std::decay_t<decltype(getEndomorphismOpOperand)>,
96 std::decay_t<decltype(getEndomorphismOpResult)>,
97 std::decay_t<decltype(getAlgebraicOpOperands)>,
98 std::decay_t<decltype(getAlgebraicOpResult)>,
99 std::decay_t<decltype(isEndomorphismOp)>,
100 std::decay_t<decltype(isAlgebraicOp)>>;
101 patterns.add(std::make_unique<ConcreteEndomorphismSimplification>(
102 std::move(getEndomorphismOpOperand), std::move(getEndomorphismOpResult),
103 std::move(getAlgebraicOpOperands), std::move(getAlgebraicOpResult),
104 std::move(isEndomorphismOp), std::move(isAlgebraicOp),
105 AlgebraicOp::getOperationName(), 1, patterns.getContext()));
106 }
107
108 // It is invalid to change ops that declare symbols during the application of
109 // these patterns, because symbolTableCollection is used to cache them.
110 void populateSimplificationPatterns(
111 RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
112 void populateFoldingPatterns(RewritePatternSet &patterns,
113 SymbolTableCollection &symbolTableCollection);
114
115 } // namespace mesh
116 } // namespace mlir
117
118 #endif // MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
119