xref: /llvm-project/mlir/lib/Transforms/RemoveDeadValues.cpp (revision aa3c31a86f39552d11f0d5bae8b50541d73aa442)
1 //===- RemoveDeadValues.cpp - Remove Dead Values --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The goal of this pass is optimization (reducing runtime) by removing
10 // unnecessary instructions. Unlike other passes that rely on local information
11 // gathered from patterns to accomplish optimization, this pass uses a full
12 // analysis of the IR, specifically, liveness analysis, and is thus more
13 // powerful.
14 //
15 // Currently, this pass performs the following optimizations:
16 // (A) Removes function arguments that are not live,
17 // (B) Removes function return values that are not live across all callers of
18 // the function,
19 // (C) Removes unnecessary operands, results, region arguments, and region
20 // terminator operands of region branch ops, and,
21 // (D) Removes simple and region branch ops that have all non-live results and
22 // don't affect memory in any way,
23 //
24 // iff
25 //
26 // the IR doesn't have any non-function symbol ops, non-call symbol user ops and
27 // branch ops.
28 //
29 // Here, a "simple op" refers to an op that isn't a symbol op, symbol-user op,
30 // region branch op, branch op, region branch terminator op, or return-like.
31 //
32 //===----------------------------------------------------------------------===//
33 
34 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
35 #include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
36 #include "mlir/IR/Attributes.h"
37 #include "mlir/IR/Builders.h"
38 #include "mlir/IR/BuiltinAttributes.h"
39 #include "mlir/IR/Dialect.h"
40 #include "mlir/IR/IRMapping.h"
41 #include "mlir/IR/OperationSupport.h"
42 #include "mlir/IR/SymbolTable.h"
43 #include "mlir/IR/Value.h"
44 #include "mlir/IR/ValueRange.h"
45 #include "mlir/IR/Visitors.h"
46 #include "mlir/Interfaces/CallInterfaces.h"
47 #include "mlir/Interfaces/ControlFlowInterfaces.h"
48 #include "mlir/Interfaces/FunctionInterfaces.h"
49 #include "mlir/Interfaces/SideEffectInterfaces.h"
50 #include "mlir/Pass/Pass.h"
51 #include "mlir/Support/LLVM.h"
52 #include "mlir/Transforms/FoldUtils.h"
53 #include "mlir/Transforms/Passes.h"
54 #include "llvm/ADT/STLExtras.h"
55 #include <cassert>
56 #include <cstddef>
57 #include <memory>
58 #include <optional>
59 #include <vector>
60 
61 namespace mlir {
62 #define GEN_PASS_DEF_REMOVEDEADVALUES
63 #include "mlir/Transforms/Passes.h.inc"
64 } // namespace mlir
65 
66 using namespace mlir;
67 using namespace mlir::dataflow;
68 
69 //===----------------------------------------------------------------------===//
70 // RemoveDeadValues Pass
71 //===----------------------------------------------------------------------===//
72 
73 namespace {
74 
75 // Set of structures below to be filled with operations and arguments to erase.
76 // This is done to separate analysis and tree modification phases,
77 // otherwise analysis is operating on half-deleted tree which is incorrect.
78 
79 struct FunctionToCleanUp {
80   FunctionOpInterface funcOp;
81   BitVector nonLiveArgs;
82   BitVector nonLiveRets;
83 };
84 
85 struct OperationToCleanup {
86   Operation *op;
87   BitVector nonLive;
88 };
89 
90 struct BlockArgsToCleanup {
91   Block *b;
92   BitVector nonLiveArgs;
93 };
94 
95 struct SuccessorOperandsToCleanup {
96   BranchOpInterface branch;
97   unsigned successorIndex;
98   BitVector nonLiveOperands;
99 };
100 
101 struct RDVFinalCleanupList {
102   SmallVector<Operation *> operations;
103   SmallVector<Value> values;
104   SmallVector<FunctionToCleanUp> functions;
105   SmallVector<OperationToCleanup> operands;
106   SmallVector<OperationToCleanup> results;
107   SmallVector<BlockArgsToCleanup> blocks;
108   SmallVector<SuccessorOperandsToCleanup> successorOperands;
109 };
110 
111 // Some helper functions...
112 
113 /// Return true iff at least one value in `values` is live, given the liveness
114 /// information in `la`.
115 static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
116                     RunLivenessAnalysis &la) {
117   for (Value value : values) {
118     if (nonLiveSet.contains(value))
119       continue;
120 
121     const Liveness *liveness = la.getLiveness(value);
122     if (!liveness || liveness->isLive)
123       return true;
124   }
125   return false;
126 }
127 
128 /// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
129 /// i-th value in `values` is live, given the liveness information in `la`.
130 static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
131                            RunLivenessAnalysis &la) {
132   BitVector lives(values.size(), true);
133 
134   for (auto [index, value] : llvm::enumerate(values)) {
135     if (nonLiveSet.contains(value)) {
136       lives.reset(index);
137       continue;
138     }
139 
140     const Liveness *liveness = la.getLiveness(value);
141     // It is important to note that when `liveness` is null, we can't tell if
142     // `value` is live or not. So, the safe option is to consider it live. Also,
143     // the execution of this pass might create new SSA values when erasing some
144     // of the results of an op and we know that these new values are live
145     // (because they weren't erased) and also their liveness is null because
146     // liveness analysis ran before their creation.
147     if (liveness && !liveness->isLive)
148       lives.reset(index);
149   }
150 
151   return lives;
152 }
153 
154 /// Collects values marked as "non-live" in the provided range and inserts them
155 /// into the nonLiveSet. A value is considered "non-live" if the corresponding
156 /// index in the `nonLive` bit vector is set.
157 static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range,
158                                  const BitVector &nonLive) {
159   for (auto [index, result] : llvm::enumerate(range)) {
160     if (!nonLive[index])
161       continue;
162     nonLiveSet.insert(result);
163   }
164 }
165 
166 /// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
167 /// is 1.
168 static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
169   assert(op->getNumResults() == toErase.size() &&
170          "expected the number of results in `op` and the size of `toErase` to "
171          "be the same");
172 
173   std::vector<Type> newResultTypes;
174   for (OpResult result : op->getResults())
175     if (!toErase[result.getResultNumber()])
176       newResultTypes.push_back(result.getType());
177   OpBuilder builder(op);
178   builder.setInsertionPointAfter(op);
179   OperationState state(op->getLoc(), op->getName().getStringRef(),
180                        op->getOperands(), newResultTypes, op->getAttrs());
181   for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
182     state.addRegion();
183   Operation *newOp = builder.create(state);
184   for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
185     Region &newRegion = newOp->getRegion(index);
186     // Move all blocks of `region` into `newRegion`.
187     Block *temp = new Block();
188     newRegion.push_back(temp);
189     while (!region.empty())
190       region.front().moveBefore(temp);
191     temp->erase();
192   }
193 
194   unsigned indexOfNextNewCallOpResultToReplace = 0;
195   for (auto [index, result] : llvm::enumerate(op->getResults())) {
196     assert(result && "expected result to be non-null");
197     if (toErase[index]) {
198       result.dropAllUses();
199     } else {
200       result.replaceAllUsesWith(
201           newOp->getResult(indexOfNextNewCallOpResultToReplace++));
202     }
203   }
204   op->erase();
205 }
206 
207 /// Convert a list of `Operand`s to a list of `OpOperand`s.
208 static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
209   OpOperand *values = operands.getBase();
210   SmallVector<OpOperand *> opOperands;
211   for (unsigned i = 0, e = operands.size(); i < e; i++)
212     opOperands.push_back(&values[i]);
213   return opOperands;
214 }
215 
216 /// Process a simple operation `op` using the liveness analysis `la`.
217 /// If the operation has no memory effects and none of its results are live:
218 ///   1. Add the operation to a list for future removal, and
219 ///   2. Mark all its results as non-live values
220 ///
221 /// The operation `op` is assumed to be simple. A simple operation is one that
222 /// is NOT:
223 ///   - Function-like
224 ///   - Call-like
225 ///   - A region branch operation
226 ///   - A branch operation
227 ///   - A region branch terminator
228 ///   - Return-like
229 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
230                             DenseSet<Value> &nonLiveSet,
231                             RDVFinalCleanupList &cl) {
232   if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la))
233     return;
234 
235   cl.operations.push_back(op);
236   collectNonLiveValues(nonLiveSet, op->getResults(),
237                        BitVector(op->getNumResults(), true));
238 }
239 
240 /// Process a function-like operation `funcOp` using the liveness analysis `la`
241 /// and the IR in `module`. If it is not public or external:
242 ///   (1) Adding its non-live arguments to a list for future removal.
243 ///   (2) Marking their corresponding operands in its callers for removal.
244 ///   (3) Identifying and enqueueing unnecessary terminator operands
245 ///       (return values that are non-live across all callers) for removal.
246 ///   (4) Enqueueing the non-live arguments and return values for removal.
247 ///   (5) Collecting the uses of these return values in its callers for future
248 ///       removal.
249 ///   (6) Marking all its results as non-live values.
250 static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
251                           RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
252                           RDVFinalCleanupList &cl) {
253   if (funcOp.isPublic() || funcOp.isExternal())
254     return;
255 
256   // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
257   SmallVector<Value> arguments(funcOp.getArguments());
258   BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
259   nonLiveArgs = nonLiveArgs.flip();
260 
261   // Do (1).
262   for (auto [index, arg] : llvm::enumerate(arguments))
263     if (arg && nonLiveArgs[index]) {
264       cl.values.push_back(arg);
265       nonLiveSet.insert(arg);
266     }
267 
268   // Do (2).
269   SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
270   for (SymbolTable::SymbolUse use : uses) {
271     Operation *callOp = use.getUser();
272     assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
273     // The number of operands in the call op may not match the number of
274     // arguments in the func op.
275     BitVector nonLiveCallOperands(callOp->getNumOperands(), false);
276     SmallVector<OpOperand *> callOpOperands =
277         operandsToOpOperands(cast<CallOpInterface>(callOp).getArgOperands());
278     for (int index : nonLiveArgs.set_bits())
279       nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
280     cl.operands.push_back({callOp, nonLiveCallOperands});
281   }
282 
283   // Do (3).
284   // Get the list of unnecessary terminator operands (return values that are
285   // non-live across all callers) in `nonLiveRets`. There is a very important
286   // subtlety here. Unnecessary terminator operands are NOT the operands of the
287   // terminator that are non-live. Instead, these are the return values of the
288   // callers such that a given return value is non-live across all callers. Such
289   // corresponding operands in the terminator could be live. An example to
290   // demonstrate this:
291   //  func.func private @f(%arg0: memref<i32>) -> (i32, i32) {
292   //    %c0_i32 = arith.constant 0 : i32
293   //    %0 = arith.addi %c0_i32, %c0_i32 : i32
294   //    memref.store %0, %arg0[] : memref<i32>
295   //    return %c0_i32, %0 : i32, i32
296   //  }
297   //  func.func @main(%arg0: i32, %arg1: memref<i32>) -> (i32) {
298   //    %1:2 = call @f(%arg1) : (memref<i32>) -> i32
299   //    return %1#0 : i32
300   //  }
301   // Here, we can see that %1#1 is never used. It is non-live. Thus, @f doesn't
302   // need to return %0. But, %0 is live. And, still, we want to stop it from
303   // being returned, in order to optimize our IR. So, this demonstrates how we
304   // can make our optimization strong by even removing a live return value (%0),
305   // since it forwards only to non-live value(s) (%1#1).
306   Operation *lastReturnOp = funcOp.back().getTerminator();
307   size_t numReturns = lastReturnOp->getNumOperands();
308   BitVector nonLiveRets(numReturns, true);
309   for (SymbolTable::SymbolUse use : uses) {
310     Operation *callOp = use.getUser();
311     assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
312     BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
313     nonLiveRets &= liveCallRets.flip();
314   }
315 
316   // Note that in the absence of control flow ops forcing the control to go from
317   // the entry (first) block to the other blocks, the control never reaches any
318   // block other than the entry block, because every block has a terminator.
319   for (Block &block : funcOp.getBlocks()) {
320     Operation *returnOp = block.getTerminator();
321     if (returnOp && returnOp->getNumOperands() == numReturns)
322       cl.operands.push_back({returnOp, nonLiveRets});
323   }
324 
325   // Do (4).
326   cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
327 
328   // Do (5) and (6).
329   for (SymbolTable::SymbolUse use : uses) {
330     Operation *callOp = use.getUser();
331     assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
332     cl.results.push_back({callOp, nonLiveRets});
333     collectNonLiveValues(nonLiveSet, callOp->getResults(), nonLiveRets);
334   }
335 }
336 
337 /// Process a region branch operation `regionBranchOp` using the liveness
338 /// information in `la`. The processing involves two scenarios:
339 ///
340 /// Scenario 1: If the operation has no memory effects and none of its results
341 /// are live:
342 ///   (1') Enqueue all its uses for deletion.
343 ///   (2') Enqueue the branch itself for deletion.
344 ///
345 /// Scenario 2: Otherwise:
346 ///   (1) Collect its unnecessary operands (operands forwarded to unnecessary
347 ///       results or arguments).
348 ///   (2) Process each of its regions.
349 ///   (3) Collect the uses of its unnecessary results (results forwarded from
350 ///       unnecessary operands
351 ///       or terminator operands).
352 ///   (4) Add these results to the deletion list.
353 ///
354 /// Processing a region includes:
355 ///   (a) Collecting the uses of its unnecessary arguments (arguments forwarded
356 ///       from unnecessary operands
357 ///       or terminator operands).
358 ///   (b) Collecting these unnecessary arguments.
359 ///   (c) Collecting its unnecessary terminator operands (terminator operands
360 ///       forwarded to unnecessary results
361 ///       or arguments).
362 ///
363 /// Value Flow Note: In this operation, values flow as follows:
364 /// - From operands and terminator operands (successor operands)
365 /// - To arguments and results (successor inputs).
366 static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
367                                   RunLivenessAnalysis &la,
368                                   DenseSet<Value> &nonLiveSet,
369                                   RDVFinalCleanupList &cl) {
370   // Mark live results of `regionBranchOp` in `liveResults`.
371   auto markLiveResults = [&](BitVector &liveResults) {
372     liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
373   };
374 
375   // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
376   auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
377     for (Region &region : regionBranchOp->getRegions()) {
378       SmallVector<Value> arguments(region.front().getArguments());
379       BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
380       liveArgs[&region] = regionLiveArgs;
381     }
382   };
383 
384   // Return the successors of `region` if the latter is not null. Else return
385   // the successors of `regionBranchOp`.
386   auto getSuccessors = [&](Region *region = nullptr) {
387     auto point = region ? region : RegionBranchPoint::parent();
388     SmallVector<Attribute> operandAttributes(regionBranchOp->getNumOperands(),
389                                              nullptr);
390     SmallVector<RegionSuccessor> successors;
391     regionBranchOp.getSuccessorRegions(point, successors);
392     return successors;
393   };
394 
395   // Return the operands of `terminator` that are forwarded to `successor` if
396   // the former is not null. Else return the operands of `regionBranchOp`
397   // forwarded to `successor`.
398   auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
399                                     Operation *terminator = nullptr) {
400     OperandRange operands =
401         terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
402                          .getSuccessorOperands(successor)
403                    : regionBranchOp.getEntrySuccessorOperands(successor);
404     SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
405     return opOperands;
406   };
407 
408   // Mark the non-forwarded operands of `regionBranchOp` in
409   // `nonForwardedOperands`.
410   auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
411     nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
412     for (const RegionSuccessor &successor : getSuccessors()) {
413       for (OpOperand *opOperand : getForwardedOpOperands(successor))
414         nonForwardedOperands.reset(opOperand->getOperandNumber());
415     }
416   };
417 
418   // Mark the non-forwarded terminator operands of the various regions of
419   // `regionBranchOp` in `nonForwardedRets`.
420   auto markNonForwardedReturnValues =
421       [&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
422         for (Region &region : regionBranchOp->getRegions()) {
423           Operation *terminator = region.front().getTerminator();
424           nonForwardedRets[terminator] =
425               BitVector(terminator->getNumOperands(), true);
426           for (const RegionSuccessor &successor : getSuccessors(&region)) {
427             for (OpOperand *opOperand :
428                  getForwardedOpOperands(successor, terminator))
429               nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
430           }
431         }
432       };
433 
434   // Update `valuesToKeep` (which is expected to correspond to operands or
435   // terminator operands) based on `resultsToKeep` and `argsToKeep`, given
436   // `region`. When `valuesToKeep` correspond to operands, `region` is null.
437   // Else, `region` is the parent region of the terminator.
438   auto updateOperandsOrTerminatorOperandsToKeep =
439       [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
440           DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
441         Operation *terminator =
442             region ? region->front().getTerminator() : nullptr;
443 
444         for (const RegionSuccessor &successor : getSuccessors(region)) {
445           Region *successorRegion = successor.getSuccessor();
446           for (auto [opOperand, input] :
447                llvm::zip(getForwardedOpOperands(successor, terminator),
448                          successor.getSuccessorInputs())) {
449             size_t operandNum = opOperand->getOperandNumber();
450             bool updateBasedOn =
451                 successorRegion
452                     ? argsToKeep[successorRegion]
453                                 [cast<BlockArgument>(input).getArgNumber()]
454                     : resultsToKeep[cast<OpResult>(input).getResultNumber()];
455             valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
456           }
457         }
458       };
459 
460   // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
461   // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a
462   // value is modified, else, false.
463   auto recomputeResultsAndArgsToKeep =
464       [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
465           BitVector &operandsToKeep,
466           DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
467           bool &resultsOrArgsToKeepChanged) {
468         resultsOrArgsToKeepChanged = false;
469 
470         // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
471         for (const RegionSuccessor &successor : getSuccessors()) {
472           Region *successorRegion = successor.getSuccessor();
473           for (auto [opOperand, input] :
474                llvm::zip(getForwardedOpOperands(successor),
475                          successor.getSuccessorInputs())) {
476             bool recomputeBasedOn =
477                 operandsToKeep[opOperand->getOperandNumber()];
478             bool toRecompute =
479                 successorRegion
480                     ? argsToKeep[successorRegion]
481                                 [cast<BlockArgument>(input).getArgNumber()]
482                     : resultsToKeep[cast<OpResult>(input).getResultNumber()];
483             if (!toRecompute && recomputeBasedOn)
484               resultsOrArgsToKeepChanged = true;
485             if (successorRegion) {
486               argsToKeep[successorRegion][cast<BlockArgument>(input)
487                                               .getArgNumber()] =
488                   argsToKeep[successorRegion]
489                             [cast<BlockArgument>(input).getArgNumber()] |
490                   recomputeBasedOn;
491             } else {
492               resultsToKeep[cast<OpResult>(input).getResultNumber()] =
493                   resultsToKeep[cast<OpResult>(input).getResultNumber()] |
494                   recomputeBasedOn;
495             }
496           }
497         }
498 
499         // Recompute `resultsToKeep` and `argsToKeep` based on
500         // `terminatorOperandsToKeep`.
501         for (Region &region : regionBranchOp->getRegions()) {
502           Operation *terminator = region.front().getTerminator();
503           for (const RegionSuccessor &successor : getSuccessors(&region)) {
504             Region *successorRegion = successor.getSuccessor();
505             for (auto [opOperand, input] :
506                  llvm::zip(getForwardedOpOperands(successor, terminator),
507                            successor.getSuccessorInputs())) {
508               bool recomputeBasedOn =
509                   terminatorOperandsToKeep[region.back().getTerminator()]
510                                           [opOperand->getOperandNumber()];
511               bool toRecompute =
512                   successorRegion
513                       ? argsToKeep[successorRegion]
514                                   [cast<BlockArgument>(input).getArgNumber()]
515                       : resultsToKeep[cast<OpResult>(input).getResultNumber()];
516               if (!toRecompute && recomputeBasedOn)
517                 resultsOrArgsToKeepChanged = true;
518               if (successorRegion) {
519                 argsToKeep[successorRegion][cast<BlockArgument>(input)
520                                                 .getArgNumber()] =
521                     argsToKeep[successorRegion]
522                               [cast<BlockArgument>(input).getArgNumber()] |
523                     recomputeBasedOn;
524               } else {
525                 resultsToKeep[cast<OpResult>(input).getResultNumber()] =
526                     resultsToKeep[cast<OpResult>(input).getResultNumber()] |
527                     recomputeBasedOn;
528               }
529             }
530           }
531         }
532       };
533 
534   // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
535   // `operandsToKeep`, and `terminatorOperandsToKeep`.
536   auto markValuesToKeep =
537       [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
538           BitVector &operandsToKeep,
539           DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
540         bool resultsOrArgsToKeepChanged = true;
541         // We keep updating and recomputing the values until we reach a point
542         // where they stop changing.
543         while (resultsOrArgsToKeepChanged) {
544           // Update the operands that need to be kept.
545           updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
546                                                    resultsToKeep, argsToKeep);
547 
548           // Update the terminator operands that need to be kept.
549           for (Region &region : regionBranchOp->getRegions()) {
550             updateOperandsOrTerminatorOperandsToKeep(
551                 terminatorOperandsToKeep[region.back().getTerminator()],
552                 resultsToKeep, argsToKeep, &region);
553           }
554 
555           // Recompute the results and arguments that need to be kept.
556           recomputeResultsAndArgsToKeep(
557               resultsToKeep, argsToKeep, operandsToKeep,
558               terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
559         }
560       };
561 
562   // Scenario 1. This is the only case where the entire `regionBranchOp`
563   // is removed. It will not happen in any other scenario. Note that in this
564   // case, a non-forwarded operand of `regionBranchOp` could be live/non-live.
565   // It could never be live because of this op but its liveness could have been
566   // attributed to something else.
567   // Do (1') and (2').
568   if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
569       !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
570     cl.operations.push_back(regionBranchOp.getOperation());
571     return;
572   }
573 
574   // Scenario 2.
575   // At this point, we know that every non-forwarded operand of `regionBranchOp`
576   // is live.
577 
578   // Stores the results of `regionBranchOp` that we want to keep.
579   BitVector resultsToKeep;
580   // Stores the mapping from regions of `regionBranchOp` to their arguments that
581   // we want to keep.
582   DenseMap<Region *, BitVector> argsToKeep;
583   // Stores the operands of `regionBranchOp` that we want to keep.
584   BitVector operandsToKeep;
585   // Stores the mapping from region terminators in `regionBranchOp` to their
586   // operands that we want to keep.
587   DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
588 
589   // Initializing the above variables...
590 
591   // The live results of `regionBranchOp` definitely need to be kept.
592   markLiveResults(resultsToKeep);
593   // Similarly, the live arguments of the regions in `regionBranchOp` definitely
594   // need to be kept.
595   markLiveArgs(argsToKeep);
596   // The non-forwarded operands of `regionBranchOp` definitely need to be kept.
597   // A live forwarded operand can be removed but no non-forwarded operand can be
598   // removed since it "controls" the flow of data in this control flow op.
599   markNonForwardedOperands(operandsToKeep);
600   // Similarly, the non-forwarded terminator operands of the regions in
601   // `regionBranchOp` definitely need to be kept.
602   markNonForwardedReturnValues(terminatorOperandsToKeep);
603 
604   // Mark the values (results, arguments, operands, and terminator operands)
605   // that we want to keep.
606   markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
607                    terminatorOperandsToKeep);
608 
609   // Do (1).
610   cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
611 
612   // Do (2.a) and (2.b).
613   for (Region &region : regionBranchOp->getRegions()) {
614     assert(!region.empty() && "expected a non-empty region in an op "
615                               "implementing `RegionBranchOpInterface`");
616     BitVector argsToRemove = argsToKeep[&region].flip();
617     cl.blocks.push_back({&region.front(), argsToRemove});
618     collectNonLiveValues(nonLiveSet, region.front().getArguments(),
619                          argsToRemove);
620   }
621 
622   // Do (2.c).
623   for (Region &region : regionBranchOp->getRegions()) {
624     Operation *terminator = region.front().getTerminator();
625     cl.operands.push_back(
626         {terminator, terminatorOperandsToKeep[terminator].flip()});
627   }
628 
629   // Do (3) and (4).
630   BitVector resultsToRemove = resultsToKeep.flip();
631   collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
632                        resultsToRemove);
633   cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
634 }
635 
636 /// Steps to process a `BranchOpInterface` operation:
637 /// Iterate through each successor block of `branchOp`.
638 /// (1) For each successor block, gather all operands from all successors.
639 /// (2) Fetch their associated liveness analysis data and collect for future
640 ///     removal.
641 /// (3) Identify and collect the dead operands from the successor block
642 ///     as well as their corresponding arguments.
643 
644 static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
645                             DenseSet<Value> &nonLiveSet,
646                             RDVFinalCleanupList &cl) {
647   unsigned numSuccessors = branchOp->getNumSuccessors();
648 
649   for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
650     Block *successorBlock = branchOp->getSuccessor(succIdx);
651 
652     // Do (1)
653     SuccessorOperands successorOperands =
654         branchOp.getSuccessorOperands(succIdx);
655     SmallVector<Value> operandValues;
656     for (unsigned operandIdx = 0; operandIdx < successorOperands.size();
657          ++operandIdx) {
658       operandValues.push_back(successorOperands[operandIdx]);
659     }
660 
661     // Do (2)
662     BitVector successorNonLive =
663         markLives(operandValues, nonLiveSet, la).flip();
664     collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
665                          successorNonLive);
666 
667     // Do (3)
668     cl.blocks.push_back({successorBlock, successorNonLive});
669     cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
670   }
671 }
672 
673 /// Removes dead values collected in RDVFinalCleanupList.
674 /// To be run once when all dead values have been collected.
675 static void cleanUpDeadVals(RDVFinalCleanupList &list) {
676   // 1. Operations
677   for (auto &op : list.operations) {
678     op->dropAllUses();
679     op->erase();
680   }
681 
682   // 2. Values
683   for (auto &v : list.values) {
684     v.dropAllUses();
685   }
686 
687   // 3. Functions
688   for (auto &f : list.functions) {
689     f.funcOp.eraseArguments(f.nonLiveArgs);
690     f.funcOp.eraseResults(f.nonLiveRets);
691   }
692 
693   // 4. Operands
694   for (auto &o : list.operands) {
695     o.op->eraseOperands(o.nonLive);
696   }
697 
698   // 5. Results
699   for (auto &r : list.results) {
700     dropUsesAndEraseResults(r.op, r.nonLive);
701   }
702 
703   // 6. Blocks
704   for (auto &b : list.blocks) {
705     // blocks that are accessed via multiple codepaths processed once
706     if (b.b->getNumArguments() != b.nonLiveArgs.size())
707       continue;
708     // it iterates backwards because erase invalidates all successor indexes
709     for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
710       if (!b.nonLiveArgs[i])
711         continue;
712       b.b->getArgument(i).dropAllUses();
713       b.b->eraseArgument(i);
714     }
715   }
716 
717   // 7. Successor Operands
718   for (auto &op : list.successorOperands) {
719     SuccessorOperands successorOperands =
720         op.branch.getSuccessorOperands(op.successorIndex);
721     // blocks that are accessed via multiple codepaths processed once
722     if (successorOperands.size() != op.nonLiveOperands.size())
723       continue;
724     // it iterates backwards because erase invalidates all successor indexes
725     for (int i = successorOperands.size() - 1; i >= 0; --i) {
726       if (!op.nonLiveOperands[i])
727         continue;
728       successorOperands.erase(i);
729     }
730   }
731 }
732 
733 struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
734   void runOnOperation() override;
735 };
736 } // namespace
737 
738 void RemoveDeadValues::runOnOperation() {
739   auto &la = getAnalysis<RunLivenessAnalysis>();
740   Operation *module = getOperation();
741 
742   // Tracks values eligible for erasure - complements liveness analysis to
743   // identify "droppable" values.
744   DenseSet<Value> deadVals;
745 
746   // Maintains a list of Ops, values, branches, etc., slated for cleanup at the
747   // end of this pass.
748   RDVFinalCleanupList finalCleanupList;
749 
750   module->walk([&](Operation *op) {
751     if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
752       processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
753     } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
754       processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
755     } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
756       processBranchOp(branchOp, la, deadVals, finalCleanupList);
757     } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
758       // Nothing to do here because this is a terminator op and it should be
759       // honored with respect to its parent
760     } else if (isa<CallOpInterface>(op)) {
761       // Nothing to do because this op is associated with a function op and gets
762       // cleaned when the latter is cleaned.
763     } else {
764       processSimpleOp(op, la, deadVals, finalCleanupList);
765     }
766   });
767 
768   cleanUpDeadVals(finalCleanupList);
769 }
770 
771 std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
772   return std::make_unique<RemoveDeadValues>();
773 }
774