xref: /llvm-project/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp (revision 938cdd60d4938e32a7f4f1620e3d9c11aabc4af5)
1 //===- CheckUses.cpp - Expensive transform value validity checks ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines a pass that performs expensive opt-in checks for Transform
10 // dialect values being potentially used after they have been consumed.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Transform/Transforms/Passes.h"
15 
16 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
17 #include "mlir/Interfaces/SideEffectInterfaces.h"
18 #include "mlir/Pass/Pass.h"
19 #include "llvm/ADT/SetOperations.h"
20 
21 namespace mlir {
22 namespace transform {
23 #define GEN_PASS_DEF_CHECKUSESPASS
24 #include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
25 } // namespace transform
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 namespace {
31 
32 /// Returns a reference to a cached set of blocks that are reachable from the
33 /// given block via edges computed by the `getNextNodes` function. For example,
34 /// if `getNextNodes` returns successors of a block, this will return the set of
35 /// reachable blocks; if it returns predecessors of a block, this will return
36 /// the set of blocks from which the given block can be reached. The block is
37 /// considered reachable form itself only if there is a cycle.
38 template <typename FnTy>
39 const llvm::SmallPtrSet<Block *, 4> &
40 getReachableImpl(Block *block, FnTy getNextNodes,
41                  DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> &cache) {
42   auto [it, inserted] = cache.try_emplace(block);
43   if (!inserted)
44     return it->getSecond();
45 
46   llvm::SmallPtrSet<Block *, 4> &reachable = it->second;
47   SmallVector<Block *> worklist;
48   worklist.push_back(block);
49   while (!worklist.empty()) {
50     Block *current = worklist.pop_back_val();
51     for (Block *predecessor : getNextNodes(current)) {
52       // The block is reachable from its transitive predecessors. Only add
53       // them to the worklist if they weren't already visited.
54       if (reachable.insert(predecessor).second)
55         worklist.push_back(predecessor);
56     }
57   }
58   return reachable;
59 }
60 
61 /// An analysis that identifies whether a value allocated by a Transform op may
62 /// be used by another such op after it may have been freed by a third op on
63 /// some control flow path. This is conceptually similar to a data flow
64 /// analysis, but relies on side effects related to particular values that
65 /// currently cannot be modeled by the MLIR data flow analysis framework (also,
66 /// the lattice element would be rather expensive as it would need to include
67 /// live and/or freed values for each operation).
68 ///
69 /// This analysis is conservatively pessimisic: it will consider that a value
70 /// may be freed if it is freed on any possible control flow path between its
71 /// allocation and a relevant use, even if the control never actually flows
72 /// through the operation that frees the value. It also does not differentiate
73 /// between may- (freed on at least one control flow path) and must-free (freed
74 /// on all possible control flow paths) because it would require expensive graph
75 /// algorithms.
76 ///
77 /// It is intended as an additional non-blocking verification or debugging aid
78 /// for ops in the Transform dialect. It leverages the requirement for Transform
79 /// dialect ops to implement the MemoryEffectsOpInterface, and expects the
80 /// values in the Transform IR to have an allocation effect on the
81 /// TransformMappingResource when defined.
82 class TransformOpMemFreeAnalysis {
83 public:
84   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransformOpMemFreeAnalysis)
85 
86   /// Computes the analysis for Transform ops nested in the given operation.
87   explicit TransformOpMemFreeAnalysis(Operation *root) {
88     root->walk([&](Operation *op) {
89       if (isa<transform::TransformOpInterface>(op)) {
90         collectFreedValues(op);
91         return WalkResult::skip();
92       }
93       return WalkResult::advance();
94     });
95   }
96 
97   /// A list of operations that may be deleting a value. Non-empty list
98   /// contextually converts to boolean "true" value.
99   class PotentialDeleters {
100   public:
101     /// Creates an empty list that corresponds to the value being live.
102     static PotentialDeleters live() { return PotentialDeleters({}); }
103 
104     /// Creates a list from the operations that may be deleting the value.
105     static PotentialDeleters maybeFreed(ArrayRef<Operation *> deleters) {
106       return PotentialDeleters(deleters);
107     }
108 
109     /// Converts to "true" if there are operations that may be deleting the
110     /// value.
111     explicit operator bool() const { return !deleters.empty(); }
112 
113     /// Concatenates the lists of operations that may be deleting the value. The
114     /// value is known to be live if the reuslting list is still empty.
115     PotentialDeleters &operator|=(const PotentialDeleters &other) {
116       llvm::append_range(deleters, other.deleters);
117       return *this;
118     }
119 
120     /// Returns the list of ops that may be deleting the value.
121     ArrayRef<Operation *> getOps() const { return deleters; }
122 
123   private:
124     /// Constructs the list from the given operations.
125     explicit PotentialDeleters(ArrayRef<Operation *> ops) {
126       llvm::append_range(deleters, ops);
127     }
128 
129     /// The list of operations that may be deleting the value.
130     SmallVector<Operation *> deleters;
131   };
132 
133   /// Returns the list of operations that may be deleting the operand value on
134   /// any control flow path between the definition of the value and its use as
135   /// the given operand. For the purposes of this analysis, the value is
136   /// considered to be allocated at its definition point and never re-allocated.
137   PotentialDeleters isUseLive(OpOperand &operand) {
138     const llvm::SmallPtrSet<Operation *, 2> &deleters = freedBy[operand.get()];
139     if (deleters.empty())
140       return live();
141 
142 #ifndef NDEBUG
143     // Check that the definition point actually allocates the value. If the
144     // definition is a block argument, it may be just forwarding the operand of
145     // the parent op without doing a new allocation, allow that. We currently
146     // don't have the capability to analyze region-based control flow here.
147     //
148     // TODO: when this ported to the dataflow analysis infra, we should have
149     // proper support for region-based control flow.
150     Operation *valueSource =
151         isa<OpResult>(operand.get())
152             ? operand.get().getDefiningOp()
153             : operand.get().getParentBlock()->getParentOp();
154     auto iface = cast<MemoryEffectOpInterface>(valueSource);
155     SmallVector<MemoryEffects::EffectInstance> instances;
156     iface.getEffectsOnResource(transform::TransformMappingResource::get(),
157                                instances);
158     assert((isa<BlockArgument>(operand.get()) ||
159             hasEffect<MemoryEffects::Allocate>(instances, operand.get())) &&
160            "expected the op defining the value to have an allocation effect "
161            "on it");
162 #endif
163 
164     // Collect ancestors of the use operation.
165     Block *defBlock = operand.get().getParentBlock();
166     SmallVector<Operation *> ancestors;
167     Operation *ancestor = operand.getOwner();
168     do {
169       ancestors.push_back(ancestor);
170       if (ancestor->getParentRegion() == defBlock->getParent())
171         break;
172       ancestor = ancestor->getParentOp();
173     } while (true);
174     std::reverse(ancestors.begin(), ancestors.end());
175 
176     // Consider the control flow from the definition point of the value to its
177     // use point. If the use is located in some nested region, consider the path
178     // from the entry block of the region to the use.
179     for (Operation *ancestor : ancestors) {
180       // The block should be considered partially if it is the block that
181       // contains the definition (allocation) of the value being used, and the
182       // value is defined in the middle of the block, i.e., is not a block
183       // argument.
184       bool isOutermost = ancestor == ancestors.front();
185       bool isFromBlockPartial = isOutermost && isa<OpResult>(operand.get());
186 
187       // Check if the value may be freed by operations between its definition
188       // (allocation) point in its block and the terminator of the block or the
189       // ancestor of the use if it is located in the same block. This is only
190       // done for partial blocks here, full blocks will be considered below
191       // similarly to other blocks.
192       if (isFromBlockPartial) {
193         bool defUseSameBlock = ancestor->getBlock() == defBlock;
194         // Consider all ops from the def to its block terminator, except the
195         // when the use is in the same block, in which case only consider the
196         // ops until the user.
197         if (PotentialDeleters potentialDeleters = isFreedInBlockAfter(
198                 operand.get().getDefiningOp(), operand.get(),
199                 defUseSameBlock ? ancestor : nullptr))
200           return potentialDeleters;
201       }
202 
203       // Check if the value may be freed by opeations preceding the ancestor in
204       // its block. Skip the check for partial blocks that contain both the
205       // definition and the use point, as this has been already checked above.
206       if (!isFromBlockPartial || ancestor->getBlock() != defBlock) {
207         if (PotentialDeleters potentialDeleters =
208                 isFreedInBlockBefore(ancestor, operand.get()))
209           return potentialDeleters;
210       }
211 
212       // Check if the value may be freed by operations in any of the blocks
213       // between the definition point (in the outermost region) or the entry
214       // block of the region (in other regions) and the operand or its ancestor
215       // in the region. This includes the entire "form" block if (1) the block
216       // has not been considered as partial above and (2) the block can be
217       // reached again through some control-flow loop. This includes the entire
218       // "to" block if it can be reached form itself through some control-flow
219       // cycle, regardless of whether it has been visited before.
220       Block *ancestorBlock = ancestor->getBlock();
221       Block *from =
222           isOutermost ? defBlock : &ancestorBlock->getParent()->front();
223       if (PotentialDeleters potentialDeleters =
224               isMaybeFreedOnPaths(from, ancestorBlock, operand.get(),
225                                   /*alwaysIncludeFrom=*/!isFromBlockPartial))
226         return potentialDeleters;
227     }
228     return live();
229   }
230 
231 private:
232   /// Make PotentialDeleters constructors available with shorter names.
233   static PotentialDeleters maybeFreed(ArrayRef<Operation *> deleters) {
234     return PotentialDeleters::maybeFreed(deleters);
235   }
236   static PotentialDeleters live() { return PotentialDeleters::live(); }
237 
238   /// Returns the list of operations that may be deleting the given value betwen
239   /// the first and last operations, non-inclusive. `getNext` indicates the
240   /// direction of the traversal.
241   PotentialDeleters
242   isFreedBetween(Value value, Operation *first, Operation *last,
243                  llvm::function_ref<Operation *(Operation *)> getNext) const {
244     auto it = freedBy.find(value);
245     if (it == freedBy.end())
246       return live();
247     const llvm::SmallPtrSet<Operation *, 2> &deleters = it->getSecond();
248     for (Operation *op = getNext(first); op != last; op = getNext(op)) {
249       if (deleters.contains(op))
250         return maybeFreed(op);
251     }
252     return live();
253   }
254 
255   /// Returns the list of operations that may be deleting the given value
256   /// between `root` and `before` values. `root` is expected to be in the same
257   /// block as `before` and precede it. If `before` is null, consider all
258   /// operations until the end of the block including the terminator.
259   PotentialDeleters isFreedInBlockAfter(Operation *root, Value value,
260                                         Operation *before = nullptr) const {
261     return isFreedBetween(value, root, before,
262                           [](Operation *op) { return op->getNextNode(); });
263   }
264 
265   /// Returns the list of operations that may be deleting the given value
266   /// between the entry of the block and the `root` operation.
267   PotentialDeleters isFreedInBlockBefore(Operation *root, Value value) const {
268     return isFreedBetween(value, root, nullptr,
269                           [](Operation *op) { return op->getPrevNode(); });
270   }
271 
272   /// Returns the list of operations that may be deleting the given value on
273   /// any of the control flow paths between the "form" and the "to" block. The
274   /// operations from any block visited on any control flow path are
275   /// consdiered. The "from" block is considered if there is a control flow
276   /// cycle going through it, i.e., if there is a possibility that all
277   /// operations in this block are visited or if the `alwaysIncludeFrom` flag is
278   /// set. The "to" block is considered only if there is a control flow cycle
279   /// going through it.
280   PotentialDeleters isMaybeFreedOnPaths(Block *from, Block *to, Value value,
281                                         bool alwaysIncludeFrom) {
282     // Find all blocks that lie on any path between "from" and "to", i.e., the
283     // intersection of blocks reachable from "from" and blocks from which "to"
284     // is rechable.
285     const llvm::SmallPtrSet<Block *, 4> &sources = getReachableFrom(to);
286     if (!sources.contains(from))
287       return live();
288 
289     llvm::SmallPtrSet<Block *, 4> reachable(getReachable(from));
290     llvm::set_intersect(reachable, sources);
291 
292     // If requested, include the "from" block that may not be present in the set
293     // of visited blocks when there is no cycle going through it.
294     if (alwaysIncludeFrom)
295       reachable.insert(from);
296 
297     // Join potential deleters from all blocks as we don't know here which of
298     // the paths through the control flow is taken.
299     PotentialDeleters potentialDeleters = live();
300     for (Block *block : reachable) {
301       for (Operation &op : *block) {
302         if (freedBy[value].count(&op))
303           potentialDeleters |= maybeFreed(&op);
304       }
305     }
306     return potentialDeleters;
307   }
308 
309   /// Popualtes `reachable` with the set of blocks that are rechable from the
310   /// given block. A block is considered reachable from itself if there is a
311   /// cycle in the control-flow graph that invovles the block.
312   const llvm::SmallPtrSet<Block *, 4> &getReachable(Block *block) {
313     return getReachableImpl(
314         block, [](Block *b) { return b->getSuccessors(); }, reachableCache);
315   }
316 
317   /// Populates `sources` with the set of blocks from which the given block is
318   /// reachable.
319   const llvm::SmallPtrSet<Block *, 4> &getReachableFrom(Block *block) {
320     return getReachableImpl(
321         block, [](Block *b) { return b->getPredecessors(); },
322         reachableFromCache);
323   }
324 
325   /// Returns true of `instances` contains an effect of `EffectTy` on `value`.
326   template <typename EffectTy>
327   static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> instances,
328                         Value value) {
329     return llvm::any_of(instances,
330                         [&](const MemoryEffects::EffectInstance &instance) {
331                           return instance.getValue() == value &&
332                                  isa<EffectTy>(instance.getEffect());
333                         });
334   }
335 
336   /// Records the values that are being freed by an operation or any of its
337   /// children in `freedBy`.
338   void collectFreedValues(Operation *root) {
339     SmallVector<MemoryEffects::EffectInstance> instances;
340     root->walk([&](Operation *child) {
341       if (isa<transform::PatternDescriptorOpInterface>(child))
342         return;
343       // TODO: extend this to conservatively handle operations with undeclared
344       // side effects as maybe freeing the operands.
345       auto iface = cast<MemoryEffectOpInterface>(child);
346       instances.clear();
347       iface.getEffectsOnResource(transform::TransformMappingResource::get(),
348                                  instances);
349       for (Value operand : child->getOperands()) {
350         if (hasEffect<MemoryEffects::Free>(instances, operand)) {
351           // All parents of the operation that frees a value should be
352           // considered as potentially freeing the value as well.
353           //
354           // TODO: differentiate between must-free/may-free as well as between
355           // this op having the effect and children having the effect. This may
356           // require some analysis of all control flow paths through the nested
357           // regions as well as a mechanism to separate proper side effects from
358           // those obtained by nesting.
359           Operation *parent = child;
360           do {
361             freedBy[operand].insert(parent);
362             if (parent == root)
363               break;
364             parent = parent->getParentOp();
365           } while (true);
366         }
367       }
368     });
369   }
370 
371   /// The mapping from a value to operations that have a Free memory effect on
372   /// the TransformMappingResource and associated with this value, or to
373   /// Transform operations transitively containing such operations.
374   DenseMap<Value, llvm::SmallPtrSet<Operation *, 2>> freedBy;
375 
376   /// Caches for sets of reachable blocks.
377   DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> reachableCache;
378   DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> reachableFromCache;
379 };
380 
381 //// A simple pass that warns about any use of a value by a transform operation
382 // that may be using the value after it has been freed.
383 class CheckUsesPass : public transform::impl::CheckUsesPassBase<CheckUsesPass> {
384 public:
385   void runOnOperation() override {
386     auto &analysis = getAnalysis<TransformOpMemFreeAnalysis>();
387 
388     getOperation()->walk([&](Operation *child) {
389       for (OpOperand &operand : child->getOpOperands()) {
390         TransformOpMemFreeAnalysis::PotentialDeleters deleters =
391             analysis.isUseLive(operand);
392         if (!deleters)
393           continue;
394 
395         InFlightDiagnostic diag = child->emitWarning()
396                                   << "operand #" << operand.getOperandNumber()
397                                   << " may be used after free";
398         diag.attachNote(operand.get().getLoc()) << "allocated here";
399         for (Operation *d : deleters.getOps()) {
400           diag.attachNote(d->getLoc()) << "freed here";
401         }
402       }
403     });
404   }
405 };
406 
407 } // namespace
408 
409