xref: /llvm-project/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h (revision ff2720d190e0dbd5f157d5d3614d0ab11fe9e7b2)
1 //===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
10 #define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
11 
12 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Transforms/EndomorphismSimplification.h"
15 #include "llvm/Support/Casting.h"
16 #include <algorithm>
17 #include <iterator>
18 #include <memory>
19 #include <utility>
20 
21 namespace mlir {
22 
23 class SymbolTableCollection;
24 
25 namespace mesh {
26 
27 // If we have an algebraic op like "+" and a summing all-reduce,
28 // `all_reduce_sum(x) + all_reduce_sum(y)` will be transformed to
29 // `all_reduce_sum(x + y)`.
30 //
31 // Another example with `min`.
32 // `min(all_reduce_min(x), all_reduce_min(y))` will be transformed to
33 // `all_reduce_min(min(x, y))`.
34 //
35 // Works only with algebraic ops that have all their operands relevant
36 // to the all-reduce endomorphism.
37 // Will not work with some op `f(x, y, z)` where only `x` and `y` form
38 // the algebraic structure.
39 template <typename AlgebraicOp>
populateAllReduceEndomorphismSimplificationPatterns(RewritePatternSet & patterns,ReductionKind reduction)40 void populateAllReduceEndomorphismSimplificationPatterns(
41     RewritePatternSet &patterns, ReductionKind reduction) {
42   auto getEndomorphismOpOperand = [](Operation *op) {
43     auto allReduceOp = llvm::cast<AllReduceOp>(op);
44     return &allReduceOp.getInputMutable();
45   };
46   auto getEndomorphismOpResult = [](Operation *op) {
47     auto allReduceOp = llvm::cast<AllReduceOp>(op);
48     return allReduceOp->getResult(0);
49   };
50   auto getAlgebraicOpOperands = [](Operation *op,
51                                    SmallVector<OpOperand *> &operands) {
52     auto algebraicOp = llvm::cast<AlgebraicOp>(op);
53     std::transform(algebraicOp->getOpOperands().begin(),
54                    algebraicOp->getOpOperands().end(),
55                    std::back_inserter(operands),
56                    [](OpOperand &operand) { return &operand; });
57   };
58   auto getAlgebraicOpResult = [](Operation *op) {
59     auto algebraicOp = llvm::cast<AlgebraicOp>(op);
60     return algebraicOp->getResult(0);
61   };
62   auto isEndomorphismOp = [reduction](Operation *op,
63                                       std::optional<Operation *> referenceOp) {
64     auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
65     if (!allReduceOp ||
66         allReduceOp.getInput().getType().getElementType() !=
67             allReduceOp.getResult().getType().getElementType() ||
68         allReduceOp.getReduction() != reduction) {
69       return false;
70     }
71 
72     // Dont't use simplify if the all-reduce is used other than by the
73     // algebraic op.
74     // TODO: maybe handle this by an additional pass that later reverses the
75     // simplification if there are other uses left other optimizations have
76     // been done.
77     if (!allReduceOp->hasOneUse()) {
78       return false;
79     }
80 
81     if (!referenceOp) {
82       return true;
83     }
84 
85     auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
86     return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
87            allReduceOp.getInput().getType().getElementType() ==
88                refAllReduceOp.getInput().getType().getElementType();
89   };
90   auto isAlgebraicOp = [](Operation *op) {
91     return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
92   };
93 
94   using ConcreteEndomorphismSimplification = EndomorphismSimplification<
95       std::decay_t<decltype(getEndomorphismOpOperand)>,
96       std::decay_t<decltype(getEndomorphismOpResult)>,
97       std::decay_t<decltype(getAlgebraicOpOperands)>,
98       std::decay_t<decltype(getAlgebraicOpResult)>,
99       std::decay_t<decltype(isEndomorphismOp)>,
100       std::decay_t<decltype(isAlgebraicOp)>>;
101   patterns.add(std::make_unique<ConcreteEndomorphismSimplification>(
102       std::move(getEndomorphismOpOperand), std::move(getEndomorphismOpResult),
103       std::move(getAlgebraicOpOperands), std::move(getAlgebraicOpResult),
104       std::move(isEndomorphismOp), std::move(isAlgebraicOp),
105       AlgebraicOp::getOperationName(), 1, patterns.getContext()));
106 }
107 
108 // It is invalid to change ops that declare symbols during the application of
109 // these patterns, because symbolTableCollection is used to cache them.
110 void populateSimplificationPatterns(
111     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
112 void populateFoldingPatterns(RewritePatternSet &patterns,
113                              SymbolTableCollection &symbolTableCollection);
114 
115 } // namespace mesh
116 } // namespace mlir
117 
118 #endif // MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
119