xref: /llvm-project/mlir/lib/Transforms/Utils/CommutativityUtils.cpp (revision 5fcf907b34355980f77d7665a175b05fea7a6b7b)
1 //===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a commutativity utility pattern and a function to
10 // populate this pattern. The function is intended to be used inside passes to
11 // simplify the matching of commutative operations by fixing the order of their
12 // operands.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "mlir/Transforms/CommutativityUtils.h"
17 
18 #include <queue>
19 
20 using namespace mlir;
21 
/// Classifies an "ancestor" of a value, i.e., an op or a block argument that
/// appears in the backward slice of that value.
///
/// The enumerators are declared in increasing numeric order; `AncestorKey`
/// compares these values directly when ranking ancestors.
enum AncestorType {
  /// The ancestor is a block argument.
  BLOCK_ARGUMENT = 0,

  /// The ancestor is an op that is not constant-like.
  NON_CONSTANT_OP = 1,

  /// The ancestor is a constant-like op.
  CONSTANT_OP = 2
};
34 
35 /// Stores the "key" associated with an ancestor.
36 struct AncestorKey {
37   /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on
38   /// the ancestor.
39   AncestorType type;
40 
41   /// Holds the op name of the ancestor if its `type` is `NON_CONSTANT_OP` or
42   /// `CONSTANT_OP`. Else, holds "".
43   StringRef opName;
44 
45   /// Constructor for `AncestorKey`.
AncestorKeyAncestorKey46   AncestorKey(Operation *op) {
47     if (!op) {
48       type = BLOCK_ARGUMENT;
49     } else {
50       type =
51           op->hasTrait<OpTrait::ConstantLike>() ? CONSTANT_OP : NON_CONSTANT_OP;
52       opName = op->getName().getStringRef();
53     }
54   }
55 
56   /// Overloaded operator `<` for `AncestorKey`.
57   ///
58   /// AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those
59   /// of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in
60   /// between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller
61   /// ones are the ones with smaller op names (lexicographically).
62   ///
63   /// TODO: Include other information like attributes, value type, etc., to
64   /// enhance this comparison. For example, currently this comparison doesn't
65   /// differentiate between `cmpi sle` and `cmpi sgt` or `addi (in i32)` and
66   /// `addi (in i64)`. Such an enhancement should only be done if the need
67   /// arises.
operator <AncestorKey68   bool operator<(const AncestorKey &key) const {
69     return std::tie(type, opName) < std::tie(key.type, key.opName);
70   }
71 };
72 
73 /// Stores a commutative operand along with its BFS traversal information.
74 struct CommutativeOperand {
75   /// Stores the operand.
76   Value operand;
77 
78   /// Stores the queue of ancestors of the operand's BFS traversal at a
79   /// particular point in time.
80   std::queue<Operation *> ancestorQueue;
81 
82   /// Stores the list of ancestors that have been visited by the BFS traversal
83   /// at a particular point in time.
84   DenseSet<Operation *> visitedAncestors;
85 
86   /// Stores the operand's "key". This "key" is defined as a list of the
87   /// "AncestorKeys" associated with the ancestors of this operand, in a
88   /// breadth-first order.
89   ///
90   /// So, if an operand, say `A`, was produced as follows:
91   ///
92   /// `<block argument>`  `<block argument>`
93   ///             \          /
94   ///              \        /
95   ///             `arith.subi`           `arith.constant`
96   ///                       \            /
97   ///                        `arith.addi`
98   ///                              |
99   ///                         returns `A`
100   ///
101   /// Then, the ancestors of `A`, in the breadth-first order are:
102   /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and
103   /// `<block argument>`.
104   ///
105   /// Thus, the "key" associated with operand `A` is:
106   /// {
107   ///  {type: `NON_CONSTANT_OP`, opName: "arith.addi"},
108   ///  {type: `NON_CONSTANT_OP`, opName: "arith.subi"},
109   ///  {type: `CONSTANT_OP`, opName: "arith.constant"},
110   ///  {type: `BLOCK_ARGUMENT`, opName: ""},
111   ///  {type: `BLOCK_ARGUMENT`, opName: ""}
112   /// }
113   SmallVector<AncestorKey, 4> key;
114 
115   /// Push an ancestor into the operand's BFS information structure. This
116   /// entails it being pushed into the queue (always) and inserted into the
117   /// "visited ancestors" list (iff it is an op rather than a block argument).
pushAncestorCommutativeOperand118   void pushAncestor(Operation *op) {
119     ancestorQueue.push(op);
120     if (op)
121       visitedAncestors.insert(op);
122   }
123 
124   /// Refresh the key.
125   ///
126   /// Refreshing a key entails making it up-to-date with the operand's BFS
127   /// traversal that has happened till that point in time, i.e, appending the
128   /// existing key with the front ancestor's "AncestorKey". Note that a key
129   /// directly reflects the BFS and thus needs to be refreshed during the
130   /// progression of the traversal.
refreshKeyCommutativeOperand131   void refreshKey() {
132     if (ancestorQueue.empty())
133       return;
134 
135     Operation *frontAncestor = ancestorQueue.front();
136     AncestorKey frontAncestorKey(frontAncestor);
137     key.push_back(frontAncestorKey);
138   }
139 
140   /// Pop the front ancestor, if any, from the queue and then push its adjacent
141   /// unvisited ancestors, if any, to the queue (this is the main body of the
142   /// BFS algorithm).
popFrontAndPushAdjacentUnvisitedAncestorsCommutativeOperand143   void popFrontAndPushAdjacentUnvisitedAncestors() {
144     if (ancestorQueue.empty())
145       return;
146     Operation *frontAncestor = ancestorQueue.front();
147     ancestorQueue.pop();
148     if (!frontAncestor)
149       return;
150     for (Value operand : frontAncestor->getOperands()) {
151       Operation *operandDefOp = operand.getDefiningOp();
152       if (!operandDefOp || !visitedAncestors.contains(operandDefOp))
153         pushAncestor(operandDefOp);
154     }
155   }
156 };
157 
158 /// Sorts the operands of `op` in ascending order of the "key" associated with
159 /// each operand iff `op` is commutative. This is a stable sort.
160 ///
161 /// After the application of this pattern, since the commutative operands now
162 /// have a deterministic order in which they occur in an op, the matching of
163 /// large DAGs becomes much simpler, i.e., requires much less number of checks
164 /// to be written by a user in her/his pattern matching function.
165 ///
166 /// Some examples of such a sorting:
167 ///
168 /// Assume that the sorting is being applied to `foo.commutative`, which is a
169 /// commutative op.
170 ///
171 /// Example 1:
172 ///
173 /// %1 = foo.const 0
174 /// %2 = foo.mul <block argument>, <block argument>
175 /// %3 = foo.commutative %1, %2
176 ///
177 /// Here,
178 /// 1. The key associated with %1 is:
179 ///     `{
180 ///       {CONSTANT_OP, "foo.const"}
181 ///      }`
182 /// 2. The key associated with %2 is:
183 ///     `{
184 ///       {NON_CONSTANT_OP, "foo.mul"},
185 ///       {BLOCK_ARGUMENT, ""},
186 ///       {BLOCK_ARGUMENT, ""}
187 ///      }`
188 ///
189 /// The key of %2 < the key of %1
190 /// Thus, the sorted `foo.commutative` is:
191 /// %3 = foo.commutative %2, %1
192 ///
193 /// Example 2:
194 ///
195 /// %1 = foo.const 0
196 /// %2 = foo.mul <block argument>, <block argument>
197 /// %3 = foo.mul %2, %1
198 /// %4 = foo.add %2, %1
199 /// %5 = foo.commutative %1, %2, %3, %4
200 ///
201 /// Here,
202 /// 1. The key associated with %1 is:
203 ///     `{
204 ///       {CONSTANT_OP, "foo.const"}
205 ///      }`
206 /// 2. The key associated with %2 is:
207 ///     `{
208 ///       {NON_CONSTANT_OP, "foo.mul"},
209 ///       {BLOCK_ARGUMENT, ""}
210 ///      }`
211 /// 3. The key associated with %3 is:
212 ///     `{
213 ///       {NON_CONSTANT_OP, "foo.mul"},
214 ///       {NON_CONSTANT_OP, "foo.mul"},
215 ///       {CONSTANT_OP, "foo.const"},
216 ///       {BLOCK_ARGUMENT, ""},
217 ///       {BLOCK_ARGUMENT, ""}
218 ///      }`
219 /// 4. The key associated with %4 is:
220 ///     `{
221 ///       {NON_CONSTANT_OP, "foo.add"},
222 ///       {NON_CONSTANT_OP, "foo.mul"},
223 ///       {CONSTANT_OP, "foo.const"},
224 ///       {BLOCK_ARGUMENT, ""},
225 ///       {BLOCK_ARGUMENT, ""}
226 ///      }`
227 ///
228 /// Thus, the sorted `foo.commutative` is:
229 /// %5 = foo.commutative %4, %3, %2, %1
230 class SortCommutativeOperands : public RewritePattern {
231 public:
SortCommutativeOperands(MLIRContext * context)232   SortCommutativeOperands(MLIRContext *context)
233       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {}
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const234   LogicalResult matchAndRewrite(Operation *op,
235                                 PatternRewriter &rewriter) const override {
236     // Custom comparator for two commutative operands, which returns true iff
237     // the "key" of `constCommOperandA` < the "key" of `constCommOperandB`,
238     // i.e.,
239     // 1. In the first unequal pair of corresponding AncestorKeys, the
240     // AncestorKey in `constCommOperandA` is smaller, or,
241     // 2. Both the AncestorKeys in every pair are the same and the size of
242     // `constCommOperandA`'s "key" is smaller.
243     auto commutativeOperandComparator =
244         [](const std::unique_ptr<CommutativeOperand> &constCommOperandA,
245            const std::unique_ptr<CommutativeOperand> &constCommOperandB) {
246           if (constCommOperandA->operand == constCommOperandB->operand)
247             return false;
248 
249           auto &commOperandA =
250               const_cast<std::unique_ptr<CommutativeOperand> &>(
251                   constCommOperandA);
252           auto &commOperandB =
253               const_cast<std::unique_ptr<CommutativeOperand> &>(
254                   constCommOperandB);
255 
256           // Iteratively perform the BFS's of both operands until an order among
257           // them can be determined.
258           unsigned keyIndex = 0;
259           while (true) {
260             if (commOperandA->key.size() <= keyIndex) {
261               if (commOperandA->ancestorQueue.empty())
262                 return true;
263               commOperandA->popFrontAndPushAdjacentUnvisitedAncestors();
264               commOperandA->refreshKey();
265             }
266             if (commOperandB->key.size() <= keyIndex) {
267               if (commOperandB->ancestorQueue.empty())
268                 return false;
269               commOperandB->popFrontAndPushAdjacentUnvisitedAncestors();
270               commOperandB->refreshKey();
271             }
272             if (commOperandA->ancestorQueue.empty() ||
273                 commOperandB->ancestorQueue.empty())
274               return commOperandA->key.size() < commOperandB->key.size();
275             if (commOperandA->key[keyIndex] < commOperandB->key[keyIndex])
276               return true;
277             if (commOperandB->key[keyIndex] < commOperandA->key[keyIndex])
278               return false;
279             keyIndex++;
280           }
281         };
282 
283     // If `op` is not commutative, do nothing.
284     if (!op->hasTrait<OpTrait::IsCommutative>())
285       return failure();
286 
287     // Populate the list of commutative operands.
288     SmallVector<Value, 2> operands = op->getOperands();
289     SmallVector<std::unique_ptr<CommutativeOperand>, 2> commOperands;
290     for (Value operand : operands) {
291       std::unique_ptr<CommutativeOperand> commOperand =
292           std::make_unique<CommutativeOperand>();
293       commOperand->operand = operand;
294       commOperand->pushAncestor(operand.getDefiningOp());
295       commOperand->refreshKey();
296       commOperands.push_back(std::move(commOperand));
297     }
298 
299     // Sort the operands.
300     std::stable_sort(commOperands.begin(), commOperands.end(),
301                      commutativeOperandComparator);
302     SmallVector<Value, 2> sortedOperands;
303     for (const std::unique_ptr<CommutativeOperand> &commOperand : commOperands)
304       sortedOperands.push_back(commOperand->operand);
305     if (sortedOperands == operands)
306       return failure();
307     rewriter.modifyOpInPlace(op, [&] { op->setOperands(sortedOperands); });
308     return success();
309   }
310 };
311 
populateCommutativityUtilsPatterns(RewritePatternSet & patterns)312 void mlir::populateCommutativityUtilsPatterns(RewritePatternSet &patterns) {
313   patterns.add<SortCommutativeOperands>(patterns.getContext());
314 }
315