//===- RemoveDeadValues.cpp - Remove Dead Values --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The goal of this pass is optimization (reducing runtime) by removing
// unnecessary instructions. Unlike other passes that rely on local information
// gathered from patterns to accomplish optimization, this pass uses a full
// analysis of the IR, specifically, liveness analysis, and is thus more
// powerful.
//
// Currently, this pass performs the following optimizations:
// (A) Removes function arguments that are not live,
// (B) Removes function return values that are not live across all callers of
// the function,
// (C) Removes unnecessary operands, results, region arguments, and region
// terminator operands of region branch ops,
// (D) Removes simple and region branch ops that have all non-live results and
// don't affect memory in any way, and,
// (E) Removes non-live successor operands of branch ops together with the
// corresponding successor block arguments,
//
// iff
//
// the IR doesn't have any non-function symbol ops and non-call symbol user ops.
//
// Here, a "simple op" refers to an op that isn't a symbol op, symbol-user op,
// region branch op, branch op, region branch terminator op, or return-like.
31 // 32 //===----------------------------------------------------------------------===// 33 34 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 35 #include "mlir/Analysis/DataFlow/LivenessAnalysis.h" 36 #include "mlir/IR/Attributes.h" 37 #include "mlir/IR/Builders.h" 38 #include "mlir/IR/BuiltinAttributes.h" 39 #include "mlir/IR/Dialect.h" 40 #include "mlir/IR/IRMapping.h" 41 #include "mlir/IR/OperationSupport.h" 42 #include "mlir/IR/SymbolTable.h" 43 #include "mlir/IR/Value.h" 44 #include "mlir/IR/ValueRange.h" 45 #include "mlir/IR/Visitors.h" 46 #include "mlir/Interfaces/CallInterfaces.h" 47 #include "mlir/Interfaces/ControlFlowInterfaces.h" 48 #include "mlir/Interfaces/FunctionInterfaces.h" 49 #include "mlir/Interfaces/SideEffectInterfaces.h" 50 #include "mlir/Pass/Pass.h" 51 #include "mlir/Support/LLVM.h" 52 #include "mlir/Transforms/FoldUtils.h" 53 #include "mlir/Transforms/Passes.h" 54 #include "llvm/ADT/STLExtras.h" 55 #include <cassert> 56 #include <cstddef> 57 #include <memory> 58 #include <optional> 59 #include <vector> 60 61 namespace mlir { 62 #define GEN_PASS_DEF_REMOVEDEADVALUES 63 #include "mlir/Transforms/Passes.h.inc" 64 } // namespace mlir 65 66 using namespace mlir; 67 using namespace mlir::dataflow; 68 69 //===----------------------------------------------------------------------===// 70 // RemoveDeadValues Pass 71 //===----------------------------------------------------------------------===// 72 73 namespace { 74 75 // Set of structures below to be filled with operations and arguments to erase. 76 // This is done to separate analysis and tree modification phases, 77 // otherwise analysis is operating on half-deleted tree which is incorrect. 
78 79 struct FunctionToCleanUp { 80 FunctionOpInterface funcOp; 81 BitVector nonLiveArgs; 82 BitVector nonLiveRets; 83 }; 84 85 struct OperationToCleanup { 86 Operation *op; 87 BitVector nonLive; 88 }; 89 90 struct BlockArgsToCleanup { 91 Block *b; 92 BitVector nonLiveArgs; 93 }; 94 95 struct SuccessorOperandsToCleanup { 96 BranchOpInterface branch; 97 unsigned successorIndex; 98 BitVector nonLiveOperands; 99 }; 100 101 struct RDVFinalCleanupList { 102 SmallVector<Operation *> operations; 103 SmallVector<Value> values; 104 SmallVector<FunctionToCleanUp> functions; 105 SmallVector<OperationToCleanup> operands; 106 SmallVector<OperationToCleanup> results; 107 SmallVector<BlockArgsToCleanup> blocks; 108 SmallVector<SuccessorOperandsToCleanup> successorOperands; 109 }; 110 111 // Some helper functions... 112 113 /// Return true iff at least one value in `values` is live, given the liveness 114 /// information in `la`. 115 static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet, 116 RunLivenessAnalysis &la) { 117 for (Value value : values) { 118 if (nonLiveSet.contains(value)) 119 continue; 120 121 const Liveness *liveness = la.getLiveness(value); 122 if (!liveness || liveness->isLive) 123 return true; 124 } 125 return false; 126 } 127 128 /// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the 129 /// i-th value in `values` is live, given the liveness information in `la`. 130 static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet, 131 RunLivenessAnalysis &la) { 132 BitVector lives(values.size(), true); 133 134 for (auto [index, value] : llvm::enumerate(values)) { 135 if (nonLiveSet.contains(value)) { 136 lives.reset(index); 137 continue; 138 } 139 140 const Liveness *liveness = la.getLiveness(value); 141 // It is important to note that when `liveness` is null, we can't tell if 142 // `value` is live or not. So, the safe option is to consider it live. 
Also, 143 // the execution of this pass might create new SSA values when erasing some 144 // of the results of an op and we know that these new values are live 145 // (because they weren't erased) and also their liveness is null because 146 // liveness analysis ran before their creation. 147 if (liveness && !liveness->isLive) 148 lives.reset(index); 149 } 150 151 return lives; 152 } 153 154 /// Collects values marked as "non-live" in the provided range and inserts them 155 /// into the nonLiveSet. A value is considered "non-live" if the corresponding 156 /// index in the `nonLive` bit vector is set. 157 static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range, 158 const BitVector &nonLive) { 159 for (auto [index, result] : llvm::enumerate(range)) { 160 if (!nonLive[index]) 161 continue; 162 nonLiveSet.insert(result); 163 } 164 } 165 166 /// Drop the uses of the i-th result of `op` and then erase it iff toErase[i] 167 /// is 1. 168 static void dropUsesAndEraseResults(Operation *op, BitVector toErase) { 169 assert(op->getNumResults() == toErase.size() && 170 "expected the number of results in `op` and the size of `toErase` to " 171 "be the same"); 172 173 std::vector<Type> newResultTypes; 174 for (OpResult result : op->getResults()) 175 if (!toErase[result.getResultNumber()]) 176 newResultTypes.push_back(result.getType()); 177 OpBuilder builder(op); 178 builder.setInsertionPointAfter(op); 179 OperationState state(op->getLoc(), op->getName().getStringRef(), 180 op->getOperands(), newResultTypes, op->getAttrs()); 181 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) 182 state.addRegion(); 183 Operation *newOp = builder.create(state); 184 for (const auto &[index, region] : llvm::enumerate(op->getRegions())) { 185 Region &newRegion = newOp->getRegion(index); 186 // Move all blocks of `region` into `newRegion`. 
187 Block *temp = new Block(); 188 newRegion.push_back(temp); 189 while (!region.empty()) 190 region.front().moveBefore(temp); 191 temp->erase(); 192 } 193 194 unsigned indexOfNextNewCallOpResultToReplace = 0; 195 for (auto [index, result] : llvm::enumerate(op->getResults())) { 196 assert(result && "expected result to be non-null"); 197 if (toErase[index]) { 198 result.dropAllUses(); 199 } else { 200 result.replaceAllUsesWith( 201 newOp->getResult(indexOfNextNewCallOpResultToReplace++)); 202 } 203 } 204 op->erase(); 205 } 206 207 /// Convert a list of `Operand`s to a list of `OpOperand`s. 208 static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) { 209 OpOperand *values = operands.getBase(); 210 SmallVector<OpOperand *> opOperands; 211 for (unsigned i = 0, e = operands.size(); i < e; i++) 212 opOperands.push_back(&values[i]); 213 return opOperands; 214 } 215 216 /// Process a simple operation `op` using the liveness analysis `la`. 217 /// If the operation has no memory effects and none of its results are live: 218 /// 1. Add the operation to a list for future removal, and 219 /// 2. Mark all its results as non-live values 220 /// 221 /// The operation `op` is assumed to be simple. A simple operation is one that 222 /// is NOT: 223 /// - Function-like 224 /// - Call-like 225 /// - A region branch operation 226 /// - A branch operation 227 /// - A region branch terminator 228 /// - Return-like 229 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, 230 DenseSet<Value> &nonLiveSet, 231 RDVFinalCleanupList &cl) { 232 if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) 233 return; 234 235 cl.operations.push_back(op); 236 collectNonLiveValues(nonLiveSet, op->getResults(), 237 BitVector(op->getNumResults(), true)); 238 } 239 240 /// Process a function-like operation `funcOp` using the liveness analysis `la` 241 /// and the IR in `module`. 
If it is not public or external: 242 /// (1) Adding its non-live arguments to a list for future removal. 243 /// (2) Marking their corresponding operands in its callers for removal. 244 /// (3) Identifying and enqueueing unnecessary terminator operands 245 /// (return values that are non-live across all callers) for removal. 246 /// (4) Enqueueing the non-live arguments and return values for removal. 247 /// (5) Collecting the uses of these return values in its callers for future 248 /// removal. 249 /// (6) Marking all its results as non-live values. 250 static void processFuncOp(FunctionOpInterface funcOp, Operation *module, 251 RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, 252 RDVFinalCleanupList &cl) { 253 if (funcOp.isPublic() || funcOp.isExternal()) 254 return; 255 256 // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`. 257 SmallVector<Value> arguments(funcOp.getArguments()); 258 BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la); 259 nonLiveArgs = nonLiveArgs.flip(); 260 261 // Do (1). 262 for (auto [index, arg] : llvm::enumerate(arguments)) 263 if (arg && nonLiveArgs[index]) { 264 cl.values.push_back(arg); 265 nonLiveSet.insert(arg); 266 } 267 268 // Do (2). 269 SymbolTable::UseRange uses = *funcOp.getSymbolUses(module); 270 for (SymbolTable::SymbolUse use : uses) { 271 Operation *callOp = use.getUser(); 272 assert(isa<CallOpInterface>(callOp) && "expected a call-like user"); 273 // The number of operands in the call op may not match the number of 274 // arguments in the func op. 275 BitVector nonLiveCallOperands(callOp->getNumOperands(), false); 276 SmallVector<OpOperand *> callOpOperands = 277 operandsToOpOperands(cast<CallOpInterface>(callOp).getArgOperands()); 278 for (int index : nonLiveArgs.set_bits()) 279 nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber()); 280 cl.operands.push_back({callOp, nonLiveCallOperands}); 281 } 282 283 // Do (3). 
284 // Get the list of unnecessary terminator operands (return values that are 285 // non-live across all callers) in `nonLiveRets`. There is a very important 286 // subtlety here. Unnecessary terminator operands are NOT the operands of the 287 // terminator that are non-live. Instead, these are the return values of the 288 // callers such that a given return value is non-live across all callers. Such 289 // corresponding operands in the terminator could be live. An example to 290 // demonstrate this: 291 // func.func private @f(%arg0: memref<i32>) -> (i32, i32) { 292 // %c0_i32 = arith.constant 0 : i32 293 // %0 = arith.addi %c0_i32, %c0_i32 : i32 294 // memref.store %0, %arg0[] : memref<i32> 295 // return %c0_i32, %0 : i32, i32 296 // } 297 // func.func @main(%arg0: i32, %arg1: memref<i32>) -> (i32) { 298 // %1:2 = call @f(%arg1) : (memref<i32>) -> i32 299 // return %1#0 : i32 300 // } 301 // Here, we can see that %1#1 is never used. It is non-live. Thus, @f doesn't 302 // need to return %0. But, %0 is live. And, still, we want to stop it from 303 // being returned, in order to optimize our IR. So, this demonstrates how we 304 // can make our optimization strong by even removing a live return value (%0), 305 // since it forwards only to non-live value(s) (%1#1). 306 Operation *lastReturnOp = funcOp.back().getTerminator(); 307 size_t numReturns = lastReturnOp->getNumOperands(); 308 BitVector nonLiveRets(numReturns, true); 309 for (SymbolTable::SymbolUse use : uses) { 310 Operation *callOp = use.getUser(); 311 assert(isa<CallOpInterface>(callOp) && "expected a call-like user"); 312 BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la); 313 nonLiveRets &= liveCallRets.flip(); 314 } 315 316 // Note that in the absence of control flow ops forcing the control to go from 317 // the entry (first) block to the other blocks, the control never reaches any 318 // block other than the entry block, because every block has a terminator. 
319 for (Block &block : funcOp.getBlocks()) { 320 Operation *returnOp = block.getTerminator(); 321 if (returnOp && returnOp->getNumOperands() == numReturns) 322 cl.operands.push_back({returnOp, nonLiveRets}); 323 } 324 325 // Do (4). 326 cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets}); 327 328 // Do (5) and (6). 329 for (SymbolTable::SymbolUse use : uses) { 330 Operation *callOp = use.getUser(); 331 assert(isa<CallOpInterface>(callOp) && "expected a call-like user"); 332 cl.results.push_back({callOp, nonLiveRets}); 333 collectNonLiveValues(nonLiveSet, callOp->getResults(), nonLiveRets); 334 } 335 } 336 337 /// Process a region branch operation `regionBranchOp` using the liveness 338 /// information in `la`. The processing involves two scenarios: 339 /// 340 /// Scenario 1: If the operation has no memory effects and none of its results 341 /// are live: 342 /// (1') Enqueue all its uses for deletion. 343 /// (2') Enqueue the branch itself for deletion. 344 /// 345 /// Scenario 2: Otherwise: 346 /// (1) Collect its unnecessary operands (operands forwarded to unnecessary 347 /// results or arguments). 348 /// (2) Process each of its regions. 349 /// (3) Collect the uses of its unnecessary results (results forwarded from 350 /// unnecessary operands 351 /// or terminator operands). 352 /// (4) Add these results to the deletion list. 353 /// 354 /// Processing a region includes: 355 /// (a) Collecting the uses of its unnecessary arguments (arguments forwarded 356 /// from unnecessary operands 357 /// or terminator operands). 358 /// (b) Collecting these unnecessary arguments. 359 /// (c) Collecting its unnecessary terminator operands (terminator operands 360 /// forwarded to unnecessary results 361 /// or arguments). 362 /// 363 /// Value Flow Note: In this operation, values flow as follows: 364 /// - From operands and terminator operands (successor operands) 365 /// - To arguments and results (successor inputs). 
366 static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, 367 RunLivenessAnalysis &la, 368 DenseSet<Value> &nonLiveSet, 369 RDVFinalCleanupList &cl) { 370 // Mark live results of `regionBranchOp` in `liveResults`. 371 auto markLiveResults = [&](BitVector &liveResults) { 372 liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la); 373 }; 374 375 // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`. 376 auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) { 377 for (Region ®ion : regionBranchOp->getRegions()) { 378 SmallVector<Value> arguments(region.front().getArguments()); 379 BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la); 380 liveArgs[®ion] = regionLiveArgs; 381 } 382 }; 383 384 // Return the successors of `region` if the latter is not null. Else return 385 // the successors of `regionBranchOp`. 386 auto getSuccessors = [&](Region *region = nullptr) { 387 auto point = region ? region : RegionBranchPoint::parent(); 388 SmallVector<Attribute> operandAttributes(regionBranchOp->getNumOperands(), 389 nullptr); 390 SmallVector<RegionSuccessor> successors; 391 regionBranchOp.getSuccessorRegions(point, successors); 392 return successors; 393 }; 394 395 // Return the operands of `terminator` that are forwarded to `successor` if 396 // the former is not null. Else return the operands of `regionBranchOp` 397 // forwarded to `successor`. 398 auto getForwardedOpOperands = [&](const RegionSuccessor &successor, 399 Operation *terminator = nullptr) { 400 OperandRange operands = 401 terminator ? cast<RegionBranchTerminatorOpInterface>(terminator) 402 .getSuccessorOperands(successor) 403 : regionBranchOp.getEntrySuccessorOperands(successor); 404 SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands); 405 return opOperands; 406 }; 407 408 // Mark the non-forwarded operands of `regionBranchOp` in 409 // `nonForwardedOperands`. 
410 auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) { 411 nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true); 412 for (const RegionSuccessor &successor : getSuccessors()) { 413 for (OpOperand *opOperand : getForwardedOpOperands(successor)) 414 nonForwardedOperands.reset(opOperand->getOperandNumber()); 415 } 416 }; 417 418 // Mark the non-forwarded terminator operands of the various regions of 419 // `regionBranchOp` in `nonForwardedRets`. 420 auto markNonForwardedReturnValues = 421 [&](DenseMap<Operation *, BitVector> &nonForwardedRets) { 422 for (Region ®ion : regionBranchOp->getRegions()) { 423 Operation *terminator = region.front().getTerminator(); 424 nonForwardedRets[terminator] = 425 BitVector(terminator->getNumOperands(), true); 426 for (const RegionSuccessor &successor : getSuccessors(®ion)) { 427 for (OpOperand *opOperand : 428 getForwardedOpOperands(successor, terminator)) 429 nonForwardedRets[terminator].reset(opOperand->getOperandNumber()); 430 } 431 } 432 }; 433 434 // Update `valuesToKeep` (which is expected to correspond to operands or 435 // terminator operands) based on `resultsToKeep` and `argsToKeep`, given 436 // `region`. When `valuesToKeep` correspond to operands, `region` is null. 437 // Else, `region` is the parent region of the terminator. 438 auto updateOperandsOrTerminatorOperandsToKeep = 439 [&](BitVector &valuesToKeep, BitVector &resultsToKeep, 440 DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) { 441 Operation *terminator = 442 region ? region->front().getTerminator() : nullptr; 443 444 for (const RegionSuccessor &successor : getSuccessors(region)) { 445 Region *successorRegion = successor.getSuccessor(); 446 for (auto [opOperand, input] : 447 llvm::zip(getForwardedOpOperands(successor, terminator), 448 successor.getSuccessorInputs())) { 449 size_t operandNum = opOperand->getOperandNumber(); 450 bool updateBasedOn = 451 successorRegion 452 ? 
argsToKeep[successorRegion] 453 [cast<BlockArgument>(input).getArgNumber()] 454 : resultsToKeep[cast<OpResult>(input).getResultNumber()]; 455 valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn; 456 } 457 } 458 }; 459 460 // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and 461 // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a 462 // value is modified, else, false. 463 auto recomputeResultsAndArgsToKeep = 464 [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep, 465 BitVector &operandsToKeep, 466 DenseMap<Operation *, BitVector> &terminatorOperandsToKeep, 467 bool &resultsOrArgsToKeepChanged) { 468 resultsOrArgsToKeepChanged = false; 469 470 // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`. 471 for (const RegionSuccessor &successor : getSuccessors()) { 472 Region *successorRegion = successor.getSuccessor(); 473 for (auto [opOperand, input] : 474 llvm::zip(getForwardedOpOperands(successor), 475 successor.getSuccessorInputs())) { 476 bool recomputeBasedOn = 477 operandsToKeep[opOperand->getOperandNumber()]; 478 bool toRecompute = 479 successorRegion 480 ? argsToKeep[successorRegion] 481 [cast<BlockArgument>(input).getArgNumber()] 482 : resultsToKeep[cast<OpResult>(input).getResultNumber()]; 483 if (!toRecompute && recomputeBasedOn) 484 resultsOrArgsToKeepChanged = true; 485 if (successorRegion) { 486 argsToKeep[successorRegion][cast<BlockArgument>(input) 487 .getArgNumber()] = 488 argsToKeep[successorRegion] 489 [cast<BlockArgument>(input).getArgNumber()] | 490 recomputeBasedOn; 491 } else { 492 resultsToKeep[cast<OpResult>(input).getResultNumber()] = 493 resultsToKeep[cast<OpResult>(input).getResultNumber()] | 494 recomputeBasedOn; 495 } 496 } 497 } 498 499 // Recompute `resultsToKeep` and `argsToKeep` based on 500 // `terminatorOperandsToKeep`. 
501 for (Region ®ion : regionBranchOp->getRegions()) { 502 Operation *terminator = region.front().getTerminator(); 503 for (const RegionSuccessor &successor : getSuccessors(®ion)) { 504 Region *successorRegion = successor.getSuccessor(); 505 for (auto [opOperand, input] : 506 llvm::zip(getForwardedOpOperands(successor, terminator), 507 successor.getSuccessorInputs())) { 508 bool recomputeBasedOn = 509 terminatorOperandsToKeep[region.back().getTerminator()] 510 [opOperand->getOperandNumber()]; 511 bool toRecompute = 512 successorRegion 513 ? argsToKeep[successorRegion] 514 [cast<BlockArgument>(input).getArgNumber()] 515 : resultsToKeep[cast<OpResult>(input).getResultNumber()]; 516 if (!toRecompute && recomputeBasedOn) 517 resultsOrArgsToKeepChanged = true; 518 if (successorRegion) { 519 argsToKeep[successorRegion][cast<BlockArgument>(input) 520 .getArgNumber()] = 521 argsToKeep[successorRegion] 522 [cast<BlockArgument>(input).getArgNumber()] | 523 recomputeBasedOn; 524 } else { 525 resultsToKeep[cast<OpResult>(input).getResultNumber()] = 526 resultsToKeep[cast<OpResult>(input).getResultNumber()] | 527 recomputeBasedOn; 528 } 529 } 530 } 531 } 532 }; 533 534 // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`, 535 // `operandsToKeep`, and `terminatorOperandsToKeep`. 536 auto markValuesToKeep = 537 [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep, 538 BitVector &operandsToKeep, 539 DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) { 540 bool resultsOrArgsToKeepChanged = true; 541 // We keep updating and recomputing the values until we reach a point 542 // where they stop changing. 543 while (resultsOrArgsToKeepChanged) { 544 // Update the operands that need to be kept. 545 updateOperandsOrTerminatorOperandsToKeep(operandsToKeep, 546 resultsToKeep, argsToKeep); 547 548 // Update the terminator operands that need to be kept. 
549 for (Region ®ion : regionBranchOp->getRegions()) { 550 updateOperandsOrTerminatorOperandsToKeep( 551 terminatorOperandsToKeep[region.back().getTerminator()], 552 resultsToKeep, argsToKeep, ®ion); 553 } 554 555 // Recompute the results and arguments that need to be kept. 556 recomputeResultsAndArgsToKeep( 557 resultsToKeep, argsToKeep, operandsToKeep, 558 terminatorOperandsToKeep, resultsOrArgsToKeepChanged); 559 } 560 }; 561 562 // Scenario 1. This is the only case where the entire `regionBranchOp` 563 // is removed. It will not happen in any other scenario. Note that in this 564 // case, a non-forwarded operand of `regionBranchOp` could be live/non-live. 565 // It could never be live because of this op but its liveness could have been 566 // attributed to something else. 567 // Do (1') and (2'). 568 if (isMemoryEffectFree(regionBranchOp.getOperation()) && 569 !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) { 570 cl.operations.push_back(regionBranchOp.getOperation()); 571 return; 572 } 573 574 // Scenario 2. 575 // At this point, we know that every non-forwarded operand of `regionBranchOp` 576 // is live. 577 578 // Stores the results of `regionBranchOp` that we want to keep. 579 BitVector resultsToKeep; 580 // Stores the mapping from regions of `regionBranchOp` to their arguments that 581 // we want to keep. 582 DenseMap<Region *, BitVector> argsToKeep; 583 // Stores the operands of `regionBranchOp` that we want to keep. 584 BitVector operandsToKeep; 585 // Stores the mapping from region terminators in `regionBranchOp` to their 586 // operands that we want to keep. 587 DenseMap<Operation *, BitVector> terminatorOperandsToKeep; 588 589 // Initializing the above variables... 590 591 // The live results of `regionBranchOp` definitely need to be kept. 592 markLiveResults(resultsToKeep); 593 // Similarly, the live arguments of the regions in `regionBranchOp` definitely 594 // need to be kept. 
595 markLiveArgs(argsToKeep); 596 // The non-forwarded operands of `regionBranchOp` definitely need to be kept. 597 // A live forwarded operand can be removed but no non-forwarded operand can be 598 // removed since it "controls" the flow of data in this control flow op. 599 markNonForwardedOperands(operandsToKeep); 600 // Similarly, the non-forwarded terminator operands of the regions in 601 // `regionBranchOp` definitely need to be kept. 602 markNonForwardedReturnValues(terminatorOperandsToKeep); 603 604 // Mark the values (results, arguments, operands, and terminator operands) 605 // that we want to keep. 606 markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep, 607 terminatorOperandsToKeep); 608 609 // Do (1). 610 cl.operands.push_back({regionBranchOp, operandsToKeep.flip()}); 611 612 // Do (2.a) and (2.b). 613 for (Region ®ion : regionBranchOp->getRegions()) { 614 assert(!region.empty() && "expected a non-empty region in an op " 615 "implementing `RegionBranchOpInterface`"); 616 BitVector argsToRemove = argsToKeep[®ion].flip(); 617 cl.blocks.push_back({®ion.front(), argsToRemove}); 618 collectNonLiveValues(nonLiveSet, region.front().getArguments(), 619 argsToRemove); 620 } 621 622 // Do (2.c). 623 for (Region ®ion : regionBranchOp->getRegions()) { 624 Operation *terminator = region.front().getTerminator(); 625 cl.operands.push_back( 626 {terminator, terminatorOperandsToKeep[terminator].flip()}); 627 } 628 629 // Do (3) and (4). 630 BitVector resultsToRemove = resultsToKeep.flip(); 631 collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(), 632 resultsToRemove); 633 cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove}); 634 } 635 636 /// Steps to process a `BranchOpInterface` operation: 637 /// Iterate through each successor block of `branchOp`. 638 /// (1) For each successor block, gather all operands from all successors. 
639 /// (2) Fetch their associated liveness analysis data and collect for future 640 /// removal. 641 /// (3) Identify and collect the dead operands from the successor block 642 /// as well as their corresponding arguments. 643 644 static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, 645 DenseSet<Value> &nonLiveSet, 646 RDVFinalCleanupList &cl) { 647 unsigned numSuccessors = branchOp->getNumSuccessors(); 648 649 for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) { 650 Block *successorBlock = branchOp->getSuccessor(succIdx); 651 652 // Do (1) 653 SuccessorOperands successorOperands = 654 branchOp.getSuccessorOperands(succIdx); 655 SmallVector<Value> operandValues; 656 for (unsigned operandIdx = 0; operandIdx < successorOperands.size(); 657 ++operandIdx) { 658 operandValues.push_back(successorOperands[operandIdx]); 659 } 660 661 // Do (2) 662 BitVector successorNonLive = 663 markLives(operandValues, nonLiveSet, la).flip(); 664 collectNonLiveValues(nonLiveSet, successorBlock->getArguments(), 665 successorNonLive); 666 667 // Do (3) 668 cl.blocks.push_back({successorBlock, successorNonLive}); 669 cl.successorOperands.push_back({branchOp, succIdx, successorNonLive}); 670 } 671 } 672 673 /// Removes dead values collected in RDVFinalCleanupList. 674 /// To be run once when all dead values have been collected. 675 static void cleanUpDeadVals(RDVFinalCleanupList &list) { 676 // 1. Operations 677 for (auto &op : list.operations) { 678 op->dropAllUses(); 679 op->erase(); 680 } 681 682 // 2. Values 683 for (auto &v : list.values) { 684 v.dropAllUses(); 685 } 686 687 // 3. Functions 688 for (auto &f : list.functions) { 689 f.funcOp.eraseArguments(f.nonLiveArgs); 690 f.funcOp.eraseResults(f.nonLiveRets); 691 } 692 693 // 4. Operands 694 for (auto &o : list.operands) { 695 o.op->eraseOperands(o.nonLive); 696 } 697 698 // 5. Results 699 for (auto &r : list.results) { 700 dropUsesAndEraseResults(r.op, r.nonLive); 701 } 702 703 // 6. 
Blocks 704 for (auto &b : list.blocks) { 705 // blocks that are accessed via multiple codepaths processed once 706 if (b.b->getNumArguments() != b.nonLiveArgs.size()) 707 continue; 708 // it iterates backwards because erase invalidates all successor indexes 709 for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) { 710 if (!b.nonLiveArgs[i]) 711 continue; 712 b.b->getArgument(i).dropAllUses(); 713 b.b->eraseArgument(i); 714 } 715 } 716 717 // 7. Successor Operands 718 for (auto &op : list.successorOperands) { 719 SuccessorOperands successorOperands = 720 op.branch.getSuccessorOperands(op.successorIndex); 721 // blocks that are accessed via multiple codepaths processed once 722 if (successorOperands.size() != op.nonLiveOperands.size()) 723 continue; 724 // it iterates backwards because erase invalidates all successor indexes 725 for (int i = successorOperands.size() - 1; i >= 0; --i) { 726 if (!op.nonLiveOperands[i]) 727 continue; 728 successorOperands.erase(i); 729 } 730 } 731 } 732 733 struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> { 734 void runOnOperation() override; 735 }; 736 } // namespace 737 738 void RemoveDeadValues::runOnOperation() { 739 auto &la = getAnalysis<RunLivenessAnalysis>(); 740 Operation *module = getOperation(); 741 742 // Tracks values eligible for erasure - complements liveness analysis to 743 // identify "droppable" values. 744 DenseSet<Value> deadVals; 745 746 // Maintains a list of Ops, values, branches, etc., slated for cleanup at the 747 // end of this pass. 
748 RDVFinalCleanupList finalCleanupList; 749 750 module->walk([&](Operation *op) { 751 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) { 752 processFuncOp(funcOp, module, la, deadVals, finalCleanupList); 753 } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) { 754 processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList); 755 } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { 756 processBranchOp(branchOp, la, deadVals, finalCleanupList); 757 } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) { 758 // Nothing to do here because this is a terminator op and it should be 759 // honored with respect to its parent 760 } else if (isa<CallOpInterface>(op)) { 761 // Nothing to do because this op is associated with a function op and gets 762 // cleaned when the latter is cleaned. 763 } else { 764 processSimpleOp(op, la, deadVals, finalCleanupList); 765 } 766 }); 767 768 cleanUpDeadVals(finalCleanupList); 769 } 770 771 std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() { 772 return std::make_unique<RemoveDeadValues>(); 773 } 774