xref: /llvm-project/mlir/lib/Analysis/SliceAnalysis.cpp (revision d97bc388fd9ef8bc38353f93ff42d894ddc4a271)
1 //===- SliceAnalysis.cpp - Analysis for Transitive UseDef chains ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Analysis functions specific to slicing in Function.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/Analysis/TopologicalSortUtils.h"
15 #include "mlir/IR/Block.h"
16 #include "mlir/IR/Operation.h"
17 #include "mlir/Interfaces/SideEffectInterfaces.h"
18 #include "mlir/Support/LLVM.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 
23 ///
24 /// Implements Analysis functions specific to slicing in Function.
25 ///
26 
27 using namespace mlir;
28 
29 static void
30 getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
31                     const SliceOptions::TransitiveFilter &filter = nullptr) {
32   if (!op)
33     return;
34 
35   // Evaluate whether we should keep this use.
36   // This is useful in particular to implement scoping; i.e. return the
37   // transitive forwardSlice in the current scope.
38   if (filter && !filter(op))
39     return;
40 
41   for (Region &region : op->getRegions())
42     for (Block &block : region)
43       for (Operation &blockOp : block)
44         if (forwardSlice->count(&blockOp) == 0)
45           getForwardSliceImpl(&blockOp, forwardSlice, filter);
46   for (Value result : op->getResults()) {
47     for (Operation *userOp : result.getUsers())
48       if (forwardSlice->count(userOp) == 0)
49         getForwardSliceImpl(userOp, forwardSlice, filter);
50   }
51 
52   forwardSlice->insert(op);
53 }
54 
55 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
56                            const ForwardSliceOptions &options) {
57   getForwardSliceImpl(op, forwardSlice, options.filter);
58   if (!options.inclusive) {
59     // Don't insert the top level operation, we just queried on it and don't
60     // want it in the results.
61     forwardSlice->remove(op);
62   }
63 
64   // Reverse to get back the actual topological order.
65   // std::reverse does not work out of the box on SetVector and I want an
66   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
67   SmallVector<Operation *, 0> v(forwardSlice->takeVector());
68   forwardSlice->insert(v.rbegin(), v.rend());
69 }
70 
71 void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
72                            const SliceOptions &options) {
73   for (Operation *user : root.getUsers())
74     getForwardSliceImpl(user, forwardSlice, options.filter);
75 
76   // Reverse to get back the actual topological order.
77   // std::reverse does not work out of the box on SetVector and I want an
78   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
79   SmallVector<Operation *, 0> v(forwardSlice->takeVector());
80   forwardSlice->insert(v.rbegin(), v.rend());
81 }
82 
83 static void getBackwardSliceImpl(Operation *op,
84                                  SetVector<Operation *> *backwardSlice,
85                                  const BackwardSliceOptions &options) {
86   if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
87     return;
88 
89   // Evaluate whether we should keep this def.
90   // This is useful in particular to implement scoping; i.e. return the
91   // transitive backwardSlice in the current scope.
92   if (options.filter && !options.filter(op))
93     return;
94 
95   auto processValue = [&](Value value) {
96     if (auto *definingOp = value.getDefiningOp()) {
97       if (backwardSlice->count(definingOp) == 0)
98         getBackwardSliceImpl(definingOp, backwardSlice, options);
99     } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
100       if (options.omitBlockArguments)
101         return;
102 
103       Block *block = blockArg.getOwner();
104       Operation *parentOp = block->getParentOp();
105       // TODO: determine whether we want to recurse backward into the other
106       // blocks of parentOp, which are not technically backward unless they flow
107       // into us. For now, just bail.
108       if (parentOp && backwardSlice->count(parentOp) == 0) {
109         assert(parentOp->getNumRegions() == 1 &&
110                parentOp->getRegion(0).getBlocks().size() == 1);
111         getBackwardSliceImpl(parentOp, backwardSlice, options);
112       }
113     } else {
114       llvm_unreachable("No definingOp and not a block argument.");
115     }
116   };
117 
118   if (!options.omitUsesFromAbove) {
119     llvm::for_each(op->getRegions(), [&](Region &region) {
120       // Walk this region recursively to collect the regions that descend from
121       // this op's nested regions (inclusive).
122       SmallPtrSet<Region *, 4> descendents;
123       region.walk(
124           [&](Region *childRegion) { descendents.insert(childRegion); });
125       region.walk([&](Operation *op) {
126         for (OpOperand &operand : op->getOpOperands()) {
127           if (!descendents.contains(operand.get().getParentRegion()))
128             processValue(operand.get());
129         }
130       });
131     });
132   }
133   llvm::for_each(op->getOperands(), processValue);
134 
135   backwardSlice->insert(op);
136 }
137 
138 void mlir::getBackwardSlice(Operation *op,
139                             SetVector<Operation *> *backwardSlice,
140                             const BackwardSliceOptions &options) {
141   getBackwardSliceImpl(op, backwardSlice, options);
142 
143   if (!options.inclusive) {
144     // Don't insert the top level operation, we just queried on it and don't
145     // want it in the results.
146     backwardSlice->remove(op);
147   }
148 }
149 
150 void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
151                             const BackwardSliceOptions &options) {
152   if (Operation *definingOp = root.getDefiningOp()) {
153     getBackwardSlice(definingOp, backwardSlice, options);
154     return;
155   }
156   Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
157   getBackwardSlice(bbAargOwner, backwardSlice, options);
158 }
159 
160 SetVector<Operation *>
161 mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
162                const ForwardSliceOptions &forwardSliceOptions) {
163   SetVector<Operation *> slice;
164   slice.insert(op);
165 
166   unsigned currentIndex = 0;
167   SetVector<Operation *> backwardSlice;
168   SetVector<Operation *> forwardSlice;
169   while (currentIndex != slice.size()) {
170     auto *currentOp = (slice)[currentIndex];
171     // Compute and insert the backwardSlice starting from currentOp.
172     backwardSlice.clear();
173     getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
174     slice.insert(backwardSlice.begin(), backwardSlice.end());
175 
176     // Compute and insert the forwardSlice starting from currentOp.
177     forwardSlice.clear();
178     getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
179     slice.insert(forwardSlice.begin(), forwardSlice.end());
180     ++currentIndex;
181   }
182   return topologicalSort(slice);
183 }
184 
185 /// Returns true if `value` (transitively) depends on iteration-carried values
186 /// of the given `ancestorOp`.
187 static bool dependsOnCarriedVals(Value value,
188                                  ArrayRef<BlockArgument> iterCarriedArgs,
189                                  Operation *ancestorOp) {
190   // Compute the backward slice of the value.
191   SetVector<Operation *> slice;
192   BackwardSliceOptions sliceOptions;
193   sliceOptions.filter = [&](Operation *op) {
194     return !ancestorOp->isAncestor(op);
195   };
196   getBackwardSlice(value, &slice, sliceOptions);
197 
198   // Check that none of the operands of the operations in the backward slice are
199   // loop iteration arguments, and neither is the value itself.
200   SmallPtrSet<Value, 8> iterCarriedValSet(iterCarriedArgs.begin(),
201                                           iterCarriedArgs.end());
202   if (iterCarriedValSet.contains(value))
203     return true;
204 
205   for (Operation *op : slice)
206     for (Value operand : op->getOperands())
207       if (iterCarriedValSet.contains(operand))
208         return true;
209 
210   return false;
211 }
212 
213 /// Utility to match a generic reduction given a list of iteration-carried
214 /// arguments, `iterCarriedArgs` and the position of the potential reduction
215 /// argument within the list, `redPos`. If a reduction is matched, returns the
216 /// reduced value and the topologically-sorted list of combiner operations
217 /// involved in the reduction. Otherwise, returns a null value.
218 ///
219 /// The matching algorithm relies on the following invariants, which are subject
220 /// to change:
221 ///  1. The first combiner operation must be a binary operation with the
222 ///     iteration-carried value and the reduced value as operands.
223 ///  2. The iteration-carried value and combiner operations must be side
224 ///     effect-free, have single result and a single use.
225 ///  3. Combiner operations must be immediately nested in the region op
226 ///     performing the reduction.
227 ///  4. Reduction def-use chain must end in a terminator op that yields the
228 ///     next iteration/output values in the same order as the iteration-carried
229 ///     values in `iterCarriedArgs`.
230 ///  5. `iterCarriedArgs` must contain all the iteration-carried/output values
231 ///     of the region op performing the reduction.
232 ///
233 /// This utility is generic enough to detect reductions involving multiple
234 /// combiner operations (disabled for now) across multiple dialects, including
235 /// Linalg, Affine and SCF. For the sake of genericity, it does not return
236 /// specific enum values for the combiner operations since its goal is also
237 /// matching reductions without pre-defined semantics in core MLIR. It's up to
238 /// each client to make sense out of the list of combiner operations. It's also
239 /// up to each client to check for additional invariants on the expected
240 /// reductions not covered by this generic matching.
241 Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs,
242                            unsigned redPos,
243                            SmallVectorImpl<Operation *> &combinerOps) {
244   assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");
245 
246   BlockArgument redCarriedVal = iterCarriedArgs[redPos];
247   if (!redCarriedVal.hasOneUse())
248     return nullptr;
249 
250   // For now, the first combiner op must be a binary op.
251   Operation *combinerOp = *redCarriedVal.getUsers().begin();
252   if (combinerOp->getNumOperands() != 2)
253     return nullptr;
254   Value reducedVal = combinerOp->getOperand(0) == redCarriedVal
255                          ? combinerOp->getOperand(1)
256                          : combinerOp->getOperand(0);
257 
258   Operation *redRegionOp =
259       iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
260   if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp))
261     return nullptr;
262 
263   // Traverse the def-use chain starting from the first combiner op until a
264   // terminator is found. Gather all the combiner ops along the way in
265   // topological order.
266   while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
267     if (!isMemoryEffectFree(combinerOp) || combinerOp->getNumResults() != 1 ||
268         !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp)
269       return nullptr;
270 
271     combinerOps.push_back(combinerOp);
272     combinerOp = *combinerOp->getUsers().begin();
273   }
274 
275   // Limit matching to single combiner op until we can properly test reductions
276   // involving multiple combiners.
277   if (combinerOps.size() != 1)
278     return nullptr;
279 
280   // Check that the yielded value is in the same position as in
281   // `iterCarriedArgs`.
282   Operation *terminatorOp = combinerOp;
283   if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
284     return nullptr;
285 
286   return reducedVal;
287 }
288