xref: /llvm-project/mlir/lib/Transforms/Utils/CommutativityUtils.cpp (revision 5fcf907b34355980f77d7665a175b05fea7a6b7b)
1b508c564Ssrishti-cb //===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===//
2b508c564Ssrishti-cb //
3b508c564Ssrishti-cb // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b508c564Ssrishti-cb // See https://llvm.org/LICENSE.txt for license information.
5b508c564Ssrishti-cb // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b508c564Ssrishti-cb //
7b508c564Ssrishti-cb //===----------------------------------------------------------------------===//
8b508c564Ssrishti-cb //
9b508c564Ssrishti-cb // This file implements a commutativity utility pattern and a function to
10b508c564Ssrishti-cb // populate this pattern. The function is intended to be used inside passes to
11b508c564Ssrishti-cb // simplify the matching of commutative operations by fixing the order of their
12b508c564Ssrishti-cb // operands.
13b508c564Ssrishti-cb //
14b508c564Ssrishti-cb //===----------------------------------------------------------------------===//
15b508c564Ssrishti-cb 
16b508c564Ssrishti-cb #include "mlir/Transforms/CommutativityUtils.h"
17b508c564Ssrishti-cb 
18b508c564Ssrishti-cb #include <queue>
19b508c564Ssrishti-cb 
20b508c564Ssrishti-cb using namespace mlir;
21b508c564Ssrishti-cb 
/// Classifies an ancestor of a value, where an ancestor is either an op or a
/// block argument found in the value's backward slice.
///
/// The numeric order of the enumerators is significant: `AncestorKey` compares
/// these values directly, so block arguments sort first and constant-like ops
/// sort last.
enum AncestorType {
  /// The ancestor is a block argument (there is no defining op).
  BLOCK_ARGUMENT = 0,

  /// The ancestor is an op that does not have the `ConstantLike` trait.
  NON_CONSTANT_OP = 1,

  /// The ancestor is an op that has the `ConstantLike` trait.
  CONSTANT_OP = 2
};
34b508c564Ssrishti-cb 
35b508c564Ssrishti-cb /// Stores the "key" associated with an ancestor.
36b508c564Ssrishti-cb struct AncestorKey {
37b508c564Ssrishti-cb   /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on
38b508c564Ssrishti-cb   /// the ancestor.
39b508c564Ssrishti-cb   AncestorType type;
40b508c564Ssrishti-cb 
41b508c564Ssrishti-cb   /// Holds the op name of the ancestor if its `type` is `NON_CONSTANT_OP` or
42b508c564Ssrishti-cb   /// `CONSTANT_OP`. Else, holds "".
43b508c564Ssrishti-cb   StringRef opName;
44b508c564Ssrishti-cb 
45b508c564Ssrishti-cb   /// Constructor for `AncestorKey`.
AncestorKeyAncestorKey46b508c564Ssrishti-cb   AncestorKey(Operation *op) {
47b508c564Ssrishti-cb     if (!op) {
48b508c564Ssrishti-cb       type = BLOCK_ARGUMENT;
49b508c564Ssrishti-cb     } else {
50b508c564Ssrishti-cb       type =
51b508c564Ssrishti-cb           op->hasTrait<OpTrait::ConstantLike>() ? CONSTANT_OP : NON_CONSTANT_OP;
52b508c564Ssrishti-cb       opName = op->getName().getStringRef();
53b508c564Ssrishti-cb     }
54b508c564Ssrishti-cb   }
55b508c564Ssrishti-cb 
56b508c564Ssrishti-cb   /// Overloaded operator `<` for `AncestorKey`.
57b508c564Ssrishti-cb   ///
58b508c564Ssrishti-cb   /// AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those
59b508c564Ssrishti-cb   /// of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in
60b508c564Ssrishti-cb   /// between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller
61b508c564Ssrishti-cb   /// ones are the ones with smaller op names (lexicographically).
62b508c564Ssrishti-cb   ///
63b508c564Ssrishti-cb   /// TODO: Include other information like attributes, value type, etc., to
64b508c564Ssrishti-cb   /// enhance this comparison. For example, currently this comparison doesn't
65b508c564Ssrishti-cb   /// differentiate between `cmpi sle` and `cmpi sgt` or `addi (in i32)` and
66b508c564Ssrishti-cb   /// `addi (in i64)`. Such an enhancement should only be done if the need
67b508c564Ssrishti-cb   /// arises.
operator <AncestorKey68b508c564Ssrishti-cb   bool operator<(const AncestorKey &key) const {
69b508c564Ssrishti-cb     return std::tie(type, opName) < std::tie(key.type, key.opName);
70b508c564Ssrishti-cb   }
71b508c564Ssrishti-cb };
72b508c564Ssrishti-cb 
73b508c564Ssrishti-cb /// Stores a commutative operand along with its BFS traversal information.
74b508c564Ssrishti-cb struct CommutativeOperand {
75b508c564Ssrishti-cb   /// Stores the operand.
76b508c564Ssrishti-cb   Value operand;
77b508c564Ssrishti-cb 
78b508c564Ssrishti-cb   /// Stores the queue of ancestors of the operand's BFS traversal at a
79b508c564Ssrishti-cb   /// particular point in time.
80b508c564Ssrishti-cb   std::queue<Operation *> ancestorQueue;
81b508c564Ssrishti-cb 
82b508c564Ssrishti-cb   /// Stores the list of ancestors that have been visited by the BFS traversal
83b508c564Ssrishti-cb   /// at a particular point in time.
84b508c564Ssrishti-cb   DenseSet<Operation *> visitedAncestors;
85b508c564Ssrishti-cb 
86b508c564Ssrishti-cb   /// Stores the operand's "key". This "key" is defined as a list of the
87b508c564Ssrishti-cb   /// "AncestorKeys" associated with the ancestors of this operand, in a
88b508c564Ssrishti-cb   /// breadth-first order.
89b508c564Ssrishti-cb   ///
90b508c564Ssrishti-cb   /// So, if an operand, say `A`, was produced as follows:
91b508c564Ssrishti-cb   ///
92b508c564Ssrishti-cb   /// `<block argument>`  `<block argument>`
93b508c564Ssrishti-cb   ///             \          /
94b508c564Ssrishti-cb   ///              \        /
95b508c564Ssrishti-cb   ///             `arith.subi`           `arith.constant`
96b508c564Ssrishti-cb   ///                       \            /
97b508c564Ssrishti-cb   ///                        `arith.addi`
98b508c564Ssrishti-cb   ///                              |
99b508c564Ssrishti-cb   ///                         returns `A`
100b508c564Ssrishti-cb   ///
101b508c564Ssrishti-cb   /// Then, the ancestors of `A`, in the breadth-first order are:
102b508c564Ssrishti-cb   /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and
103b508c564Ssrishti-cb   /// `<block argument>`.
104b508c564Ssrishti-cb   ///
105b508c564Ssrishti-cb   /// Thus, the "key" associated with operand `A` is:
106b508c564Ssrishti-cb   /// {
107b508c564Ssrishti-cb   ///  {type: `NON_CONSTANT_OP`, opName: "arith.addi"},
108b508c564Ssrishti-cb   ///  {type: `NON_CONSTANT_OP`, opName: "arith.subi"},
109b508c564Ssrishti-cb   ///  {type: `CONSTANT_OP`, opName: "arith.constant"},
110b508c564Ssrishti-cb   ///  {type: `BLOCK_ARGUMENT`, opName: ""},
111b508c564Ssrishti-cb   ///  {type: `BLOCK_ARGUMENT`, opName: ""}
112b508c564Ssrishti-cb   /// }
113b508c564Ssrishti-cb   SmallVector<AncestorKey, 4> key;
114b508c564Ssrishti-cb 
115b508c564Ssrishti-cb   /// Push an ancestor into the operand's BFS information structure. This
116b508c564Ssrishti-cb   /// entails it being pushed into the queue (always) and inserted into the
117b508c564Ssrishti-cb   /// "visited ancestors" list (iff it is an op rather than a block argument).
pushAncestorCommutativeOperand118b508c564Ssrishti-cb   void pushAncestor(Operation *op) {
119b508c564Ssrishti-cb     ancestorQueue.push(op);
120b508c564Ssrishti-cb     if (op)
121b508c564Ssrishti-cb       visitedAncestors.insert(op);
122b508c564Ssrishti-cb   }
123b508c564Ssrishti-cb 
124b508c564Ssrishti-cb   /// Refresh the key.
125b508c564Ssrishti-cb   ///
126b508c564Ssrishti-cb   /// Refreshing a key entails making it up-to-date with the operand's BFS
127b508c564Ssrishti-cb   /// traversal that has happened till that point in time, i.e, appending the
128b508c564Ssrishti-cb   /// existing key with the front ancestor's "AncestorKey". Note that a key
129b508c564Ssrishti-cb   /// directly reflects the BFS and thus needs to be refreshed during the
130b508c564Ssrishti-cb   /// progression of the traversal.
refreshKeyCommutativeOperand131b508c564Ssrishti-cb   void refreshKey() {
132b508c564Ssrishti-cb     if (ancestorQueue.empty())
133b508c564Ssrishti-cb       return;
134b508c564Ssrishti-cb 
135b508c564Ssrishti-cb     Operation *frontAncestor = ancestorQueue.front();
136b508c564Ssrishti-cb     AncestorKey frontAncestorKey(frontAncestor);
137b508c564Ssrishti-cb     key.push_back(frontAncestorKey);
138b508c564Ssrishti-cb   }
139b508c564Ssrishti-cb 
140b508c564Ssrishti-cb   /// Pop the front ancestor, if any, from the queue and then push its adjacent
141b508c564Ssrishti-cb   /// unvisited ancestors, if any, to the queue (this is the main body of the
142b508c564Ssrishti-cb   /// BFS algorithm).
popFrontAndPushAdjacentUnvisitedAncestorsCommutativeOperand143b508c564Ssrishti-cb   void popFrontAndPushAdjacentUnvisitedAncestors() {
144b508c564Ssrishti-cb     if (ancestorQueue.empty())
145b508c564Ssrishti-cb       return;
146b508c564Ssrishti-cb     Operation *frontAncestor = ancestorQueue.front();
147b508c564Ssrishti-cb     ancestorQueue.pop();
148b508c564Ssrishti-cb     if (!frontAncestor)
149b508c564Ssrishti-cb       return;
150b508c564Ssrishti-cb     for (Value operand : frontAncestor->getOperands()) {
151b508c564Ssrishti-cb       Operation *operandDefOp = operand.getDefiningOp();
152b508c564Ssrishti-cb       if (!operandDefOp || !visitedAncestors.contains(operandDefOp))
153b508c564Ssrishti-cb         pushAncestor(operandDefOp);
154b508c564Ssrishti-cb     }
155b508c564Ssrishti-cb   }
156b508c564Ssrishti-cb };
157b508c564Ssrishti-cb 
158b508c564Ssrishti-cb /// Sorts the operands of `op` in ascending order of the "key" associated with
159b508c564Ssrishti-cb /// each operand iff `op` is commutative. This is a stable sort.
160b508c564Ssrishti-cb ///
161b508c564Ssrishti-cb /// After the application of this pattern, since the commutative operands now
162b508c564Ssrishti-cb /// have a deterministic order in which they occur in an op, the matching of
163b508c564Ssrishti-cb /// large DAGs becomes much simpler, i.e., requires much less number of checks
164b508c564Ssrishti-cb /// to be written by a user in her/his pattern matching function.
165b508c564Ssrishti-cb ///
166b508c564Ssrishti-cb /// Some examples of such a sorting:
167b508c564Ssrishti-cb ///
168b508c564Ssrishti-cb /// Assume that the sorting is being applied to `foo.commutative`, which is a
169b508c564Ssrishti-cb /// commutative op.
170b508c564Ssrishti-cb ///
171b508c564Ssrishti-cb /// Example 1:
172b508c564Ssrishti-cb ///
173b508c564Ssrishti-cb /// %1 = foo.const 0
174b508c564Ssrishti-cb /// %2 = foo.mul <block argument>, <block argument>
175b508c564Ssrishti-cb /// %3 = foo.commutative %1, %2
176b508c564Ssrishti-cb ///
177b508c564Ssrishti-cb /// Here,
178b508c564Ssrishti-cb /// 1. The key associated with %1 is:
179b508c564Ssrishti-cb ///     `{
180b508c564Ssrishti-cb ///       {CONSTANT_OP, "foo.const"}
181b508c564Ssrishti-cb ///      }`
182b508c564Ssrishti-cb /// 2. The key associated with %2 is:
183b508c564Ssrishti-cb ///     `{
184b508c564Ssrishti-cb ///       {NON_CONSTANT_OP, "foo.mul"},
185b508c564Ssrishti-cb ///       {BLOCK_ARGUMENT, ""},
186b508c564Ssrishti-cb ///       {BLOCK_ARGUMENT, ""}
187b508c564Ssrishti-cb ///      }`
188b508c564Ssrishti-cb ///
189b508c564Ssrishti-cb /// The key of %2 < the key of %1
190b508c564Ssrishti-cb /// Thus, the sorted `foo.commutative` is:
191b508c564Ssrishti-cb /// %3 = foo.commutative %2, %1
192b508c564Ssrishti-cb ///
193b508c564Ssrishti-cb /// Example 2:
194b508c564Ssrishti-cb ///
195b508c564Ssrishti-cb /// %1 = foo.const 0
196b508c564Ssrishti-cb /// %2 = foo.mul <block argument>, <block argument>
197b508c564Ssrishti-cb /// %3 = foo.mul %2, %1
198b508c564Ssrishti-cb /// %4 = foo.add %2, %1
199b508c564Ssrishti-cb /// %5 = foo.commutative %1, %2, %3, %4
200b508c564Ssrishti-cb ///
201b508c564Ssrishti-cb /// Here,
202b508c564Ssrishti-cb /// 1. The key associated with %1 is:
203b508c564Ssrishti-cb ///     `{
204b508c564Ssrishti-cb ///       {CONSTANT_OP, "foo.const"}
205b508c564Ssrishti-cb ///      }`
/// 2. The key associated with %2 is:
///     `{
///       {NON_CONSTANT_OP, "foo.mul"},
///       {BLOCK_ARGUMENT, ""},
///       {BLOCK_ARGUMENT, ""}
///      }`
211b508c564Ssrishti-cb /// 3. The key associated with %3 is:
212b508c564Ssrishti-cb ///     `{
213b508c564Ssrishti-cb ///       {NON_CONSTANT_OP, "foo.mul"},
214b508c564Ssrishti-cb ///       {NON_CONSTANT_OP, "foo.mul"},
215b508c564Ssrishti-cb ///       {CONSTANT_OP, "foo.const"},
216b508c564Ssrishti-cb ///       {BLOCK_ARGUMENT, ""},
217b508c564Ssrishti-cb ///       {BLOCK_ARGUMENT, ""}
218b508c564Ssrishti-cb ///      }`
219b508c564Ssrishti-cb /// 4. The key associated with %4 is:
220b508c564Ssrishti-cb ///     `{
221b508c564Ssrishti-cb ///       {NON_CONSTANT_OP, "foo.add"},
222b508c564Ssrishti-cb ///       {NON_CONSTANT_OP, "foo.mul"},
223b508c564Ssrishti-cb ///       {CONSTANT_OP, "foo.const"},
224b508c564Ssrishti-cb ///       {BLOCK_ARGUMENT, ""},
225b508c564Ssrishti-cb ///       {BLOCK_ARGUMENT, ""}
226b508c564Ssrishti-cb ///      }`
227b508c564Ssrishti-cb ///
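/// The key of %4 < the key of %2 < the key of %3 < the key of %1. (%2 precedes
/// %3 because their keys first differ at the second position, where a
/// `BLOCK_ARGUMENT` is smaller than a `NON_CONSTANT_OP`.)
/// Thus, the sorted `foo.commutative` is:
/// %5 = foo.commutative %4, %2, %3, %1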
230b508c564Ssrishti-cb class SortCommutativeOperands : public RewritePattern {
231b508c564Ssrishti-cb public:
SortCommutativeOperands(MLIRContext * context)232b508c564Ssrishti-cb   SortCommutativeOperands(MLIRContext *context)
233b508c564Ssrishti-cb       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {}
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const234b508c564Ssrishti-cb   LogicalResult matchAndRewrite(Operation *op,
235b508c564Ssrishti-cb                                 PatternRewriter &rewriter) const override {
236b508c564Ssrishti-cb     // Custom comparator for two commutative operands, which returns true iff
237b508c564Ssrishti-cb     // the "key" of `constCommOperandA` < the "key" of `constCommOperandB`,
238b508c564Ssrishti-cb     // i.e.,
239b508c564Ssrishti-cb     // 1. In the first unequal pair of corresponding AncestorKeys, the
240b508c564Ssrishti-cb     // AncestorKey in `constCommOperandA` is smaller, or,
241b508c564Ssrishti-cb     // 2. Both the AncestorKeys in every pair are the same and the size of
242b508c564Ssrishti-cb     // `constCommOperandA`'s "key" is smaller.
243b508c564Ssrishti-cb     auto commutativeOperandComparator =
244b508c564Ssrishti-cb         [](const std::unique_ptr<CommutativeOperand> &constCommOperandA,
245b508c564Ssrishti-cb            const std::unique_ptr<CommutativeOperand> &constCommOperandB) {
246b508c564Ssrishti-cb           if (constCommOperandA->operand == constCommOperandB->operand)
247b508c564Ssrishti-cb             return false;
248b508c564Ssrishti-cb 
249b508c564Ssrishti-cb           auto &commOperandA =
250b508c564Ssrishti-cb               const_cast<std::unique_ptr<CommutativeOperand> &>(
251b508c564Ssrishti-cb                   constCommOperandA);
252b508c564Ssrishti-cb           auto &commOperandB =
253b508c564Ssrishti-cb               const_cast<std::unique_ptr<CommutativeOperand> &>(
254b508c564Ssrishti-cb                   constCommOperandB);
255b508c564Ssrishti-cb 
256b508c564Ssrishti-cb           // Iteratively perform the BFS's of both operands until an order among
257b508c564Ssrishti-cb           // them can be determined.
258b508c564Ssrishti-cb           unsigned keyIndex = 0;
259b508c564Ssrishti-cb           while (true) {
260b508c564Ssrishti-cb             if (commOperandA->key.size() <= keyIndex) {
261b508c564Ssrishti-cb               if (commOperandA->ancestorQueue.empty())
262b508c564Ssrishti-cb                 return true;
263b508c564Ssrishti-cb               commOperandA->popFrontAndPushAdjacentUnvisitedAncestors();
264b508c564Ssrishti-cb               commOperandA->refreshKey();
265b508c564Ssrishti-cb             }
266b508c564Ssrishti-cb             if (commOperandB->key.size() <= keyIndex) {
267b508c564Ssrishti-cb               if (commOperandB->ancestorQueue.empty())
268b508c564Ssrishti-cb                 return false;
269b508c564Ssrishti-cb               commOperandB->popFrontAndPushAdjacentUnvisitedAncestors();
270b508c564Ssrishti-cb               commOperandB->refreshKey();
271b508c564Ssrishti-cb             }
272b508c564Ssrishti-cb             if (commOperandA->ancestorQueue.empty() ||
273b508c564Ssrishti-cb                 commOperandB->ancestorQueue.empty())
274b508c564Ssrishti-cb               return commOperandA->key.size() < commOperandB->key.size();
275b508c564Ssrishti-cb             if (commOperandA->key[keyIndex] < commOperandB->key[keyIndex])
276b508c564Ssrishti-cb               return true;
277b508c564Ssrishti-cb             if (commOperandB->key[keyIndex] < commOperandA->key[keyIndex])
278b508c564Ssrishti-cb               return false;
279b508c564Ssrishti-cb             keyIndex++;
280b508c564Ssrishti-cb           }
281b508c564Ssrishti-cb         };
282b508c564Ssrishti-cb 
283b508c564Ssrishti-cb     // If `op` is not commutative, do nothing.
284b508c564Ssrishti-cb     if (!op->hasTrait<OpTrait::IsCommutative>())
285b508c564Ssrishti-cb       return failure();
286b508c564Ssrishti-cb 
287b508c564Ssrishti-cb     // Populate the list of commutative operands.
288b508c564Ssrishti-cb     SmallVector<Value, 2> operands = op->getOperands();
289b508c564Ssrishti-cb     SmallVector<std::unique_ptr<CommutativeOperand>, 2> commOperands;
290b508c564Ssrishti-cb     for (Value operand : operands) {
291b508c564Ssrishti-cb       std::unique_ptr<CommutativeOperand> commOperand =
292b508c564Ssrishti-cb           std::make_unique<CommutativeOperand>();
293b508c564Ssrishti-cb       commOperand->operand = operand;
294b508c564Ssrishti-cb       commOperand->pushAncestor(operand.getDefiningOp());
295b508c564Ssrishti-cb       commOperand->refreshKey();
296b508c564Ssrishti-cb       commOperands.push_back(std::move(commOperand));
297b508c564Ssrishti-cb     }
298b508c564Ssrishti-cb 
299b508c564Ssrishti-cb     // Sort the operands.
300b508c564Ssrishti-cb     std::stable_sort(commOperands.begin(), commOperands.end(),
301b508c564Ssrishti-cb                      commutativeOperandComparator);
302b508c564Ssrishti-cb     SmallVector<Value, 2> sortedOperands;
303b508c564Ssrishti-cb     for (const std::unique_ptr<CommutativeOperand> &commOperand : commOperands)
304b508c564Ssrishti-cb       sortedOperands.push_back(commOperand->operand);
305b508c564Ssrishti-cb     if (sortedOperands == operands)
306b508c564Ssrishti-cb       return failure();
307*5fcf907bSMatthias Springer     rewriter.modifyOpInPlace(op, [&] { op->setOperands(sortedOperands); });
308b508c564Ssrishti-cb     return success();
309b508c564Ssrishti-cb   }
310b508c564Ssrishti-cb };
311b508c564Ssrishti-cb 
populateCommutativityUtilsPatterns(RewritePatternSet & patterns)312b508c564Ssrishti-cb void mlir::populateCommutativityUtilsPatterns(RewritePatternSet &patterns) {
313b508c564Ssrishti-cb   patterns.add<SortCommutativeOperands>(patterns.getContext());
314b508c564Ssrishti-cb }
315