xref: /llvm-project/mlir/lib/Analysis/SliceAnalysis.cpp (revision d97bc388fd9ef8bc38353f93ff42d894ddc4a271)
//===- SliceAnalysis.cpp - Analysis for Transitive UseDef chains ----------===//
25c16564bSNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65c16564bSNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
85c16564bSNicolas Vasilache //
// This file implements analysis functions for computing slices of the IR.
105c16564bSNicolas Vasilache //
115c16564bSNicolas Vasilache //===----------------------------------------------------------------------===//
125c16564bSNicolas Vasilache 
135c16564bSNicolas Vasilache #include "mlir/Analysis/SliceAnalysis.h"
14b00e0c16SChristian Ulmann #include "mlir/Analysis/TopologicalSortUtils.h"
15b00e0c16SChristian Ulmann #include "mlir/IR/Block.h"
169ffdc930SRiver Riddle #include "mlir/IR/Operation.h"
17fc367dfaSMahesh Ravishankar #include "mlir/Interfaces/SideEffectInterfaces.h"
180c3923e1SMehdi Amini #include "mlir/Support/LLVM.h"
19*d97bc388SIan Wood #include "llvm/ADT/STLExtras.h"
205c16564bSNicolas Vasilache #include "llvm/ADT/SetVector.h"
21755dc07dSRiver Riddle #include "llvm/ADT/SmallPtrSet.h"
225c16564bSNicolas Vasilache 
235c16564bSNicolas Vasilache ///
/// Implements analysis functions for computing slices of the IR.
255c16564bSNicolas Vasilache ///
265c16564bSNicolas Vasilache 
275c16564bSNicolas Vasilache using namespace mlir;
285c16564bSNicolas Vasilache 
29641b12e9SMahesh Ravishankar static void
30641b12e9SMahesh Ravishankar getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
3106927848SMehdi Amini                     const SliceOptions::TransitiveFilter &filter = nullptr) {
32d01ea0edSNicolas Vasilache   if (!op)
335c16564bSNicolas Vasilache     return;
345c16564bSNicolas Vasilache 
355c16564bSNicolas Vasilache   // Evaluate whether we should keep this use.
365c16564bSNicolas Vasilache   // This is useful in particular to implement scoping; i.e. return the
37258dae5dSNicolas Vasilache   // transitive forwardSlice in the current scope.
38d01ea0edSNicolas Vasilache   if (filter && !filter(op))
395c16564bSNicolas Vasilache     return;
405c16564bSNicolas Vasilache 
41d01ea0edSNicolas Vasilache   for (Region &region : op->getRegions())
42d01ea0edSNicolas Vasilache     for (Block &block : region)
43d01ea0edSNicolas Vasilache       for (Operation &blockOp : block)
44d01ea0edSNicolas Vasilache         if (forwardSlice->count(&blockOp) == 0)
45d01ea0edSNicolas Vasilache           getForwardSliceImpl(&blockOp, forwardSlice, filter);
462f23270aSThomas Raoux   for (Value result : op->getResults()) {
472f23270aSThomas Raoux     for (Operation *userOp : result.getUsers())
482f23270aSThomas Raoux       if (forwardSlice->count(userOp) == 0)
492f23270aSThomas Raoux         getForwardSliceImpl(userOp, forwardSlice, filter);
505c16564bSNicolas Vasilache   }
515c16564bSNicolas Vasilache 
529c085406SRiver Riddle   forwardSlice->insert(op);
535c16564bSNicolas Vasilache }
54c3b0c6a0SNicolas Vasilache 
559c085406SRiver Riddle void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
5606927848SMehdi Amini                            const ForwardSliceOptions &options) {
57641b12e9SMahesh Ravishankar   getForwardSliceImpl(op, forwardSlice, options.filter);
58641b12e9SMahesh Ravishankar   if (!options.inclusive) {
599c085406SRiver Riddle     // Don't insert the top level operation, we just queried on it and don't
60c3b0c6a0SNicolas Vasilache     // want it in the results.
619c085406SRiver Riddle     forwardSlice->remove(op);
626d2501bfSNicolas Vasilache   }
63c3b0c6a0SNicolas Vasilache 
64c3b0c6a0SNicolas Vasilache   // Reverse to get back the actual topological order.
65c3b0c6a0SNicolas Vasilache   // std::reverse does not work out of the box on SetVector and I want an
66c3b0c6a0SNicolas Vasilache   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
671b162fabSFangrui Song   SmallVector<Operation *, 0> v(forwardSlice->takeVector());
68c3b0c6a0SNicolas Vasilache   forwardSlice->insert(v.rbegin(), v.rend());
695c16564bSNicolas Vasilache }
705c16564bSNicolas Vasilache 
71d01ea0edSNicolas Vasilache void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
7206927848SMehdi Amini                            const SliceOptions &options) {
73d01ea0edSNicolas Vasilache   for (Operation *user : root.getUsers())
74641b12e9SMahesh Ravishankar     getForwardSliceImpl(user, forwardSlice, options.filter);
75d01ea0edSNicolas Vasilache 
76d01ea0edSNicolas Vasilache   // Reverse to get back the actual topological order.
77d01ea0edSNicolas Vasilache   // std::reverse does not work out of the box on SetVector and I want an
78d01ea0edSNicolas Vasilache   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
791b162fabSFangrui Song   SmallVector<Operation *, 0> v(forwardSlice->takeVector());
80d01ea0edSNicolas Vasilache   forwardSlice->insert(v.rbegin(), v.rend());
81d01ea0edSNicolas Vasilache }
82d01ea0edSNicolas Vasilache 
839c085406SRiver Riddle static void getBackwardSliceImpl(Operation *op,
849c085406SRiver Riddle                                  SetVector<Operation *> *backwardSlice,
8506927848SMehdi Amini                                  const BackwardSliceOptions &options) {
86d01ea0edSNicolas Vasilache   if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
875c16564bSNicolas Vasilache     return;
885c16564bSNicolas Vasilache 
895c16564bSNicolas Vasilache   // Evaluate whether we should keep this def.
905c16564bSNicolas Vasilache   // This is useful in particular to implement scoping; i.e. return the
91d01ea0edSNicolas Vasilache   // transitive backwardSlice in the current scope.
92641b12e9SMahesh Ravishankar   if (options.filter && !options.filter(op))
935c16564bSNicolas Vasilache     return;
945c16564bSNicolas Vasilache 
95*d97bc388SIan Wood   auto processValue = [&](Value value) {
96*d97bc388SIan Wood     if (auto *definingOp = value.getDefiningOp()) {
97d01ea0edSNicolas Vasilache       if (backwardSlice->count(definingOp) == 0)
98641b12e9SMahesh Ravishankar         getBackwardSliceImpl(definingOp, backwardSlice, options);
99*d97bc388SIan Wood     } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
100641b12e9SMahesh Ravishankar       if (options.omitBlockArguments)
101*d97bc388SIan Wood         return;
102641b12e9SMahesh Ravishankar 
103d01ea0edSNicolas Vasilache       Block *block = blockArg.getOwner();
104d01ea0edSNicolas Vasilache       Operation *parentOp = block->getParentOp();
105d01ea0edSNicolas Vasilache       // TODO: determine whether we want to recurse backward into the other
106d01ea0edSNicolas Vasilache       // blocks of parentOp, which are not technically backward unless they flow
107d01ea0edSNicolas Vasilache       // into us. For now, just bail.
1088c33639aSHanhan Wang       if (parentOp && backwardSlice->count(parentOp) == 0) {
109d01ea0edSNicolas Vasilache         assert(parentOp->getNumRegions() == 1 &&
110d01ea0edSNicolas Vasilache                parentOp->getRegion(0).getBlocks().size() == 1);
111641b12e9SMahesh Ravishankar         getBackwardSliceImpl(parentOp, backwardSlice, options);
1128c33639aSHanhan Wang       }
113d01ea0edSNicolas Vasilache     } else {
114d01ea0edSNicolas Vasilache       llvm_unreachable("No definingOp and not a block argument.");
1155c16564bSNicolas Vasilache     }
116*d97bc388SIan Wood   };
117*d97bc388SIan Wood 
118*d97bc388SIan Wood   if (!options.omitUsesFromAbove) {
119*d97bc388SIan Wood     llvm::for_each(op->getRegions(), [&](Region &region) {
120*d97bc388SIan Wood       // Walk this region recursively to collect the regions that descend from
121*d97bc388SIan Wood       // this op's nested regions (inclusive).
122*d97bc388SIan Wood       SmallPtrSet<Region *, 4> descendents;
123*d97bc388SIan Wood       region.walk(
124*d97bc388SIan Wood           [&](Region *childRegion) { descendents.insert(childRegion); });
125*d97bc388SIan Wood       region.walk([&](Operation *op) {
126*d97bc388SIan Wood         for (OpOperand &operand : op->getOpOperands()) {
127*d97bc388SIan Wood           if (!descendents.contains(operand.get().getParentRegion()))
128*d97bc388SIan Wood             processValue(operand.get());
1295c16564bSNicolas Vasilache         }
130*d97bc388SIan Wood       });
131*d97bc388SIan Wood     });
132*d97bc388SIan Wood   }
133*d97bc388SIan Wood   llvm::for_each(op->getOperands(), processValue);
1345c16564bSNicolas Vasilache 
1359c085406SRiver Riddle   backwardSlice->insert(op);
1365c16564bSNicolas Vasilache }
137c3b0c6a0SNicolas Vasilache 
1389c085406SRiver Riddle void mlir::getBackwardSlice(Operation *op,
1399c085406SRiver Riddle                             SetVector<Operation *> *backwardSlice,
14006927848SMehdi Amini                             const BackwardSliceOptions &options) {
141641b12e9SMahesh Ravishankar   getBackwardSliceImpl(op, backwardSlice, options);
142c3b0c6a0SNicolas Vasilache 
143641b12e9SMahesh Ravishankar   if (!options.inclusive) {
1449c085406SRiver Riddle     // Don't insert the top level operation, we just queried on it and don't
145c3b0c6a0SNicolas Vasilache     // want it in the results.
1469c085406SRiver Riddle     backwardSlice->remove(op);
1475c16564bSNicolas Vasilache   }
1486d2501bfSNicolas Vasilache }
1495c16564bSNicolas Vasilache 
150d01ea0edSNicolas Vasilache void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
15106927848SMehdi Amini                             const BackwardSliceOptions &options) {
152d01ea0edSNicolas Vasilache   if (Operation *definingOp = root.getDefiningOp()) {
153641b12e9SMahesh Ravishankar     getBackwardSlice(definingOp, backwardSlice, options);
154d01ea0edSNicolas Vasilache     return;
155d01ea0edSNicolas Vasilache   }
1565550c821STres Popp   Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
157641b12e9SMahesh Ravishankar   getBackwardSlice(bbAargOwner, backwardSlice, options);
158d01ea0edSNicolas Vasilache }
159d01ea0edSNicolas Vasilache 
16006927848SMehdi Amini SetVector<Operation *>
16106927848SMehdi Amini mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
16206927848SMehdi Amini                const ForwardSliceOptions &forwardSliceOptions) {
1639c085406SRiver Riddle   SetVector<Operation *> slice;
1649c085406SRiver Riddle   slice.insert(op);
1655c16564bSNicolas Vasilache 
166258dae5dSNicolas Vasilache   unsigned currentIndex = 0;
1679c085406SRiver Riddle   SetVector<Operation *> backwardSlice;
1689c085406SRiver Riddle   SetVector<Operation *> forwardSlice;
169258dae5dSNicolas Vasilache   while (currentIndex != slice.size()) {
1706953cf65SNicolas Vasilache     auto *currentOp = (slice)[currentIndex];
1716953cf65SNicolas Vasilache     // Compute and insert the backwardSlice starting from currentOp.
172258dae5dSNicolas Vasilache     backwardSlice.clear();
173641b12e9SMahesh Ravishankar     getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
174258dae5dSNicolas Vasilache     slice.insert(backwardSlice.begin(), backwardSlice.end());
1755c16564bSNicolas Vasilache 
1766953cf65SNicolas Vasilache     // Compute and insert the forwardSlice starting from currentOp.
177258dae5dSNicolas Vasilache     forwardSlice.clear();
178641b12e9SMahesh Ravishankar     getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
179258dae5dSNicolas Vasilache     slice.insert(forwardSlice.begin(), forwardSlice.end());
1805c16564bSNicolas Vasilache     ++currentIndex;
1815c16564bSNicolas Vasilache   }
182258dae5dSNicolas Vasilache   return topologicalSort(slice);
1835c16564bSNicolas Vasilache }
1845c16564bSNicolas Vasilache 
185755dc07dSRiver Riddle /// Returns true if `value` (transitively) depends on iteration-carried values
186755dc07dSRiver Riddle /// of the given `ancestorOp`.
187755dc07dSRiver Riddle static bool dependsOnCarriedVals(Value value,
188755dc07dSRiver Riddle                                  ArrayRef<BlockArgument> iterCarriedArgs,
189755dc07dSRiver Riddle                                  Operation *ancestorOp) {
190755dc07dSRiver Riddle   // Compute the backward slice of the value.
191755dc07dSRiver Riddle   SetVector<Operation *> slice;
192641b12e9SMahesh Ravishankar   BackwardSliceOptions sliceOptions;
193641b12e9SMahesh Ravishankar   sliceOptions.filter = [&](Operation *op) {
194641b12e9SMahesh Ravishankar     return !ancestorOp->isAncestor(op);
195641b12e9SMahesh Ravishankar   };
196641b12e9SMahesh Ravishankar   getBackwardSlice(value, &slice, sliceOptions);
197755dc07dSRiver Riddle 
198755dc07dSRiver Riddle   // Check that none of the operands of the operations in the backward slice are
199755dc07dSRiver Riddle   // loop iteration arguments, and neither is the value itself.
200755dc07dSRiver Riddle   SmallPtrSet<Value, 8> iterCarriedValSet(iterCarriedArgs.begin(),
201755dc07dSRiver Riddle                                           iterCarriedArgs.end());
202755dc07dSRiver Riddle   if (iterCarriedValSet.contains(value))
203755dc07dSRiver Riddle     return true;
204755dc07dSRiver Riddle 
205755dc07dSRiver Riddle   for (Operation *op : slice)
206755dc07dSRiver Riddle     for (Value operand : op->getOperands())
207755dc07dSRiver Riddle       if (iterCarriedValSet.contains(operand))
208755dc07dSRiver Riddle         return true;
209755dc07dSRiver Riddle 
210755dc07dSRiver Riddle   return false;
211755dc07dSRiver Riddle }
212755dc07dSRiver Riddle 
213755dc07dSRiver Riddle /// Utility to match a generic reduction given a list of iteration-carried
214755dc07dSRiver Riddle /// arguments, `iterCarriedArgs` and the position of the potential reduction
215755dc07dSRiver Riddle /// argument within the list, `redPos`. If a reduction is matched, returns the
216755dc07dSRiver Riddle /// reduced value and the topologically-sorted list of combiner operations
217755dc07dSRiver Riddle /// involved in the reduction. Otherwise, returns a null value.
218755dc07dSRiver Riddle ///
219755dc07dSRiver Riddle /// The matching algorithm relies on the following invariants, which are subject
220755dc07dSRiver Riddle /// to change:
221755dc07dSRiver Riddle ///  1. The first combiner operation must be a binary operation with the
222755dc07dSRiver Riddle ///     iteration-carried value and the reduced value as operands.
223755dc07dSRiver Riddle ///  2. The iteration-carried value and combiner operations must be side
224755dc07dSRiver Riddle ///     effect-free, have single result and a single use.
225755dc07dSRiver Riddle ///  3. Combiner operations must be immediately nested in the region op
226755dc07dSRiver Riddle ///     performing the reduction.
227755dc07dSRiver Riddle ///  4. Reduction def-use chain must end in a terminator op that yields the
228755dc07dSRiver Riddle ///     next iteration/output values in the same order as the iteration-carried
229755dc07dSRiver Riddle ///     values in `iterCarriedArgs`.
230755dc07dSRiver Riddle ///  5. `iterCarriedArgs` must contain all the iteration-carried/output values
231755dc07dSRiver Riddle ///     of the region op performing the reduction.
232755dc07dSRiver Riddle ///
233755dc07dSRiver Riddle /// This utility is generic enough to detect reductions involving multiple
234755dc07dSRiver Riddle /// combiner operations (disabled for now) across multiple dialects, including
235755dc07dSRiver Riddle /// Linalg, Affine and SCF. For the sake of genericity, it does not return
236755dc07dSRiver Riddle /// specific enum values for the combiner operations since its goal is also
237755dc07dSRiver Riddle /// matching reductions without pre-defined semantics in core MLIR. It's up to
238755dc07dSRiver Riddle /// each client to make sense out of the list of combiner operations. It's also
239755dc07dSRiver Riddle /// up to each client to check for additional invariants on the expected
240755dc07dSRiver Riddle /// reductions not covered by this generic matching.
241755dc07dSRiver Riddle Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs,
242755dc07dSRiver Riddle                            unsigned redPos,
243755dc07dSRiver Riddle                            SmallVectorImpl<Operation *> &combinerOps) {
244755dc07dSRiver Riddle   assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");
245755dc07dSRiver Riddle 
246755dc07dSRiver Riddle   BlockArgument redCarriedVal = iterCarriedArgs[redPos];
247755dc07dSRiver Riddle   if (!redCarriedVal.hasOneUse())
248755dc07dSRiver Riddle     return nullptr;
249755dc07dSRiver Riddle 
250755dc07dSRiver Riddle   // For now, the first combiner op must be a binary op.
251755dc07dSRiver Riddle   Operation *combinerOp = *redCarriedVal.getUsers().begin();
252755dc07dSRiver Riddle   if (combinerOp->getNumOperands() != 2)
253755dc07dSRiver Riddle     return nullptr;
254755dc07dSRiver Riddle   Value reducedVal = combinerOp->getOperand(0) == redCarriedVal
255755dc07dSRiver Riddle                          ? combinerOp->getOperand(1)
256755dc07dSRiver Riddle                          : combinerOp->getOperand(0);
257755dc07dSRiver Riddle 
258755dc07dSRiver Riddle   Operation *redRegionOp =
259755dc07dSRiver Riddle       iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
260755dc07dSRiver Riddle   if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp))
261755dc07dSRiver Riddle     return nullptr;
262755dc07dSRiver Riddle 
263755dc07dSRiver Riddle   // Traverse the def-use chain starting from the first combiner op until a
264755dc07dSRiver Riddle   // terminator is found. Gather all the combiner ops along the way in
265755dc07dSRiver Riddle   // topological order.
266755dc07dSRiver Riddle   while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
267fc367dfaSMahesh Ravishankar     if (!isMemoryEffectFree(combinerOp) || combinerOp->getNumResults() != 1 ||
268fc367dfaSMahesh Ravishankar         !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp)
269755dc07dSRiver Riddle       return nullptr;
270755dc07dSRiver Riddle 
271755dc07dSRiver Riddle     combinerOps.push_back(combinerOp);
272755dc07dSRiver Riddle     combinerOp = *combinerOp->getUsers().begin();
273755dc07dSRiver Riddle   }
274755dc07dSRiver Riddle 
275755dc07dSRiver Riddle   // Limit matching to single combiner op until we can properly test reductions
276755dc07dSRiver Riddle   // involving multiple combiners.
277755dc07dSRiver Riddle   if (combinerOps.size() != 1)
278755dc07dSRiver Riddle     return nullptr;
279755dc07dSRiver Riddle 
280755dc07dSRiver Riddle   // Check that the yielded value is in the same position as in
281755dc07dSRiver Riddle   // `iterCarriedArgs`.
282755dc07dSRiver Riddle   Operation *terminatorOp = combinerOp;
283755dc07dSRiver Riddle   if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
284755dc07dSRiver Riddle     return nullptr;
285755dc07dSRiver Riddle 
286755dc07dSRiver Riddle   return reducedVal;
287755dc07dSRiver Riddle }
288