1258dae5dSNicolas Vasilache //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===// 25c16564bSNicolas Vasilache // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65c16564bSNicolas Vasilache // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 85c16564bSNicolas Vasilache // 969d9e990SChris Lattner // This file implements Analysis functions specific to slicing in Function. 105c16564bSNicolas Vasilache // 115c16564bSNicolas Vasilache //===----------------------------------------------------------------------===// 125c16564bSNicolas Vasilache 135c16564bSNicolas Vasilache #include "mlir/Analysis/SliceAnalysis.h" 14b00e0c16SChristian Ulmann #include "mlir/Analysis/TopologicalSortUtils.h" 15b00e0c16SChristian Ulmann #include "mlir/IR/Block.h" 169ffdc930SRiver Riddle #include "mlir/IR/Operation.h" 17fc367dfaSMahesh Ravishankar #include "mlir/Interfaces/SideEffectInterfaces.h" 180c3923e1SMehdi Amini #include "mlir/Support/LLVM.h" 19*d97bc388SIan Wood #include "llvm/ADT/STLExtras.h" 205c16564bSNicolas Vasilache #include "llvm/ADT/SetVector.h" 21755dc07dSRiver Riddle #include "llvm/ADT/SmallPtrSet.h" 225c16564bSNicolas Vasilache 235c16564bSNicolas Vasilache /// 2469d9e990SChris Lattner /// Implements Analysis functions specific to slicing in Function. 255c16564bSNicolas Vasilache /// 265c16564bSNicolas Vasilache 275c16564bSNicolas Vasilache using namespace mlir; 285c16564bSNicolas Vasilache 29641b12e9SMahesh Ravishankar static void 30641b12e9SMahesh Ravishankar getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice, 3106927848SMehdi Amini const SliceOptions::TransitiveFilter &filter = nullptr) { 32d01ea0edSNicolas Vasilache if (!op) 335c16564bSNicolas Vasilache return; 345c16564bSNicolas Vasilache 355c16564bSNicolas Vasilache // Evaluate whether we should keep this use. 365c16564bSNicolas Vasilache // This is useful in particular to implement scoping; i.e. return the 37258dae5dSNicolas Vasilache // transitive forwardSlice in the current scope. 38d01ea0edSNicolas Vasilache if (filter && !filter(op)) 395c16564bSNicolas Vasilache return; 405c16564bSNicolas Vasilache 41d01ea0edSNicolas Vasilache for (Region ®ion : op->getRegions()) 42d01ea0edSNicolas Vasilache for (Block &block : region) 43d01ea0edSNicolas Vasilache for (Operation &blockOp : block) 44d01ea0edSNicolas Vasilache if (forwardSlice->count(&blockOp) == 0) 45d01ea0edSNicolas Vasilache getForwardSliceImpl(&blockOp, forwardSlice, filter); 462f23270aSThomas Raoux for (Value result : op->getResults()) { 472f23270aSThomas Raoux for (Operation *userOp : result.getUsers()) 482f23270aSThomas Raoux if (forwardSlice->count(userOp) == 0) 492f23270aSThomas Raoux getForwardSliceImpl(userOp, forwardSlice, filter); 505c16564bSNicolas Vasilache } 515c16564bSNicolas Vasilache 529c085406SRiver Riddle forwardSlice->insert(op); 535c16564bSNicolas Vasilache } 54c3b0c6a0SNicolas Vasilache 559c085406SRiver Riddle void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice, 5606927848SMehdi Amini const ForwardSliceOptions &options) { 57641b12e9SMahesh Ravishankar getForwardSliceImpl(op, forwardSlice, options.filter); 58641b12e9SMahesh Ravishankar if (!options.inclusive) { 599c085406SRiver Riddle // Don't insert the top level operation, we just queried on it and don't 60c3b0c6a0SNicolas Vasilache // want it in the results. 619c085406SRiver Riddle forwardSlice->remove(op); 626d2501bfSNicolas Vasilache } 63c3b0c6a0SNicolas Vasilache 64c3b0c6a0SNicolas Vasilache // Reverse to get back the actual topological order. 65c3b0c6a0SNicolas Vasilache // std::reverse does not work out of the box on SetVector and I want an 66c3b0c6a0SNicolas Vasilache // in-place swap based thing (the real std::reverse, not the LLVM adapter). 671b162fabSFangrui Song SmallVector<Operation *, 0> v(forwardSlice->takeVector()); 68c3b0c6a0SNicolas Vasilache forwardSlice->insert(v.rbegin(), v.rend()); 695c16564bSNicolas Vasilache } 705c16564bSNicolas Vasilache 71d01ea0edSNicolas Vasilache void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice, 7206927848SMehdi Amini const SliceOptions &options) { 73d01ea0edSNicolas Vasilache for (Operation *user : root.getUsers()) 74641b12e9SMahesh Ravishankar getForwardSliceImpl(user, forwardSlice, options.filter); 75d01ea0edSNicolas Vasilache 76d01ea0edSNicolas Vasilache // Reverse to get back the actual topological order. 77d01ea0edSNicolas Vasilache // std::reverse does not work out of the box on SetVector and I want an 78d01ea0edSNicolas Vasilache // in-place swap based thing (the real std::reverse, not the LLVM adapter). 791b162fabSFangrui Song SmallVector<Operation *, 0> v(forwardSlice->takeVector()); 80d01ea0edSNicolas Vasilache forwardSlice->insert(v.rbegin(), v.rend()); 81d01ea0edSNicolas Vasilache } 82d01ea0edSNicolas Vasilache 839c085406SRiver Riddle static void getBackwardSliceImpl(Operation *op, 849c085406SRiver Riddle SetVector<Operation *> *backwardSlice, 8506927848SMehdi Amini const BackwardSliceOptions &options) { 86d01ea0edSNicolas Vasilache if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>()) 875c16564bSNicolas Vasilache return; 885c16564bSNicolas Vasilache 895c16564bSNicolas Vasilache // Evaluate whether we should keep this def. 905c16564bSNicolas Vasilache // This is useful in particular to implement scoping; i.e. return the 91d01ea0edSNicolas Vasilache // transitive backwardSlice in the current scope. 92641b12e9SMahesh Ravishankar if (options.filter && !options.filter(op)) 935c16564bSNicolas Vasilache return; 945c16564bSNicolas Vasilache 95*d97bc388SIan Wood auto processValue = [&](Value value) { 96*d97bc388SIan Wood if (auto *definingOp = value.getDefiningOp()) { 97d01ea0edSNicolas Vasilache if (backwardSlice->count(definingOp) == 0) 98641b12e9SMahesh Ravishankar getBackwardSliceImpl(definingOp, backwardSlice, options); 99*d97bc388SIan Wood } else if (auto blockArg = dyn_cast<BlockArgument>(value)) { 100641b12e9SMahesh Ravishankar if (options.omitBlockArguments) 101*d97bc388SIan Wood return; 102641b12e9SMahesh Ravishankar 103d01ea0edSNicolas Vasilache Block *block = blockArg.getOwner(); 104d01ea0edSNicolas Vasilache Operation *parentOp = block->getParentOp(); 105d01ea0edSNicolas Vasilache // TODO: determine whether we want to recurse backward into the other 106d01ea0edSNicolas Vasilache // blocks of parentOp, which are not technically backward unless they flow 107d01ea0edSNicolas Vasilache // into us. For now, just bail. 1088c33639aSHanhan Wang if (parentOp && backwardSlice->count(parentOp) == 0) { 109d01ea0edSNicolas Vasilache assert(parentOp->getNumRegions() == 1 && 110d01ea0edSNicolas Vasilache parentOp->getRegion(0).getBlocks().size() == 1); 111641b12e9SMahesh Ravishankar getBackwardSliceImpl(parentOp, backwardSlice, options); 1128c33639aSHanhan Wang } 113d01ea0edSNicolas Vasilache } else { 114d01ea0edSNicolas Vasilache llvm_unreachable("No definingOp and not a block argument."); 1155c16564bSNicolas Vasilache } 116*d97bc388SIan Wood }; 117*d97bc388SIan Wood 118*d97bc388SIan Wood if (!options.omitUsesFromAbove) { 119*d97bc388SIan Wood llvm::for_each(op->getRegions(), [&](Region ®ion) { 120*d97bc388SIan Wood // Walk this region recursively to collect the regions that descend from 121*d97bc388SIan Wood // this op's nested regions (inclusive). 122*d97bc388SIan Wood SmallPtrSet<Region *, 4> descendents; 123*d97bc388SIan Wood region.walk( 124*d97bc388SIan Wood [&](Region *childRegion) { descendents.insert(childRegion); }); 125*d97bc388SIan Wood region.walk([&](Operation *op) { 126*d97bc388SIan Wood for (OpOperand &operand : op->getOpOperands()) { 127*d97bc388SIan Wood if (!descendents.contains(operand.get().getParentRegion())) 128*d97bc388SIan Wood processValue(operand.get()); 1295c16564bSNicolas Vasilache } 130*d97bc388SIan Wood }); 131*d97bc388SIan Wood }); 132*d97bc388SIan Wood } 133*d97bc388SIan Wood llvm::for_each(op->getOperands(), processValue); 1345c16564bSNicolas Vasilache 1359c085406SRiver Riddle backwardSlice->insert(op); 1365c16564bSNicolas Vasilache } 137c3b0c6a0SNicolas Vasilache 1389c085406SRiver Riddle void mlir::getBackwardSlice(Operation *op, 1399c085406SRiver Riddle SetVector<Operation *> *backwardSlice, 14006927848SMehdi Amini const BackwardSliceOptions &options) { 141641b12e9SMahesh Ravishankar getBackwardSliceImpl(op, backwardSlice, options); 142c3b0c6a0SNicolas Vasilache 143641b12e9SMahesh Ravishankar if (!options.inclusive) { 1449c085406SRiver Riddle // Don't insert the top level operation, we just queried on it and don't 145c3b0c6a0SNicolas Vasilache // want it in the results. 1469c085406SRiver Riddle backwardSlice->remove(op); 1475c16564bSNicolas Vasilache } 1486d2501bfSNicolas Vasilache } 1495c16564bSNicolas Vasilache 150d01ea0edSNicolas Vasilache void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice, 15106927848SMehdi Amini const BackwardSliceOptions &options) { 152d01ea0edSNicolas Vasilache if (Operation *definingOp = root.getDefiningOp()) { 153641b12e9SMahesh Ravishankar getBackwardSlice(definingOp, backwardSlice, options); 154d01ea0edSNicolas Vasilache return; 155d01ea0edSNicolas Vasilache } 1565550c821STres Popp Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp(); 157641b12e9SMahesh Ravishankar getBackwardSlice(bbAargOwner, backwardSlice, options); 158d01ea0edSNicolas Vasilache } 159d01ea0edSNicolas Vasilache 16006927848SMehdi Amini SetVector<Operation *> 16106927848SMehdi Amini mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions, 16206927848SMehdi Amini const ForwardSliceOptions &forwardSliceOptions) { 1639c085406SRiver Riddle SetVector<Operation *> slice; 1649c085406SRiver Riddle slice.insert(op); 1655c16564bSNicolas Vasilache 166258dae5dSNicolas Vasilache unsigned currentIndex = 0; 1679c085406SRiver Riddle SetVector<Operation *> backwardSlice; 1689c085406SRiver Riddle SetVector<Operation *> forwardSlice; 169258dae5dSNicolas Vasilache while (currentIndex != slice.size()) { 1706953cf65SNicolas Vasilache auto *currentOp = (slice)[currentIndex]; 1716953cf65SNicolas Vasilache // Compute and insert the backwardSlice starting from currentOp. 172258dae5dSNicolas Vasilache backwardSlice.clear(); 173641b12e9SMahesh Ravishankar getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions); 174258dae5dSNicolas Vasilache slice.insert(backwardSlice.begin(), backwardSlice.end()); 1755c16564bSNicolas Vasilache 1766953cf65SNicolas Vasilache // Compute and insert the forwardSlice starting from currentOp. 177258dae5dSNicolas Vasilache forwardSlice.clear(); 178641b12e9SMahesh Ravishankar getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions); 179258dae5dSNicolas Vasilache slice.insert(forwardSlice.begin(), forwardSlice.end()); 1805c16564bSNicolas Vasilache ++currentIndex; 1815c16564bSNicolas Vasilache } 182258dae5dSNicolas Vasilache return topologicalSort(slice); 1835c16564bSNicolas Vasilache } 1845c16564bSNicolas Vasilache 185755dc07dSRiver Riddle /// Returns true if `value` (transitively) depends on iteration-carried values 186755dc07dSRiver Riddle /// of the given `ancestorOp`. 187755dc07dSRiver Riddle static bool dependsOnCarriedVals(Value value, 188755dc07dSRiver Riddle ArrayRef<BlockArgument> iterCarriedArgs, 189755dc07dSRiver Riddle Operation *ancestorOp) { 190755dc07dSRiver Riddle // Compute the backward slice of the value. 191755dc07dSRiver Riddle SetVector<Operation *> slice; 192641b12e9SMahesh Ravishankar BackwardSliceOptions sliceOptions; 193641b12e9SMahesh Ravishankar sliceOptions.filter = [&](Operation *op) { 194641b12e9SMahesh Ravishankar return !ancestorOp->isAncestor(op); 195641b12e9SMahesh Ravishankar }; 196641b12e9SMahesh Ravishankar getBackwardSlice(value, &slice, sliceOptions); 197755dc07dSRiver Riddle 198755dc07dSRiver Riddle // Check that none of the operands of the operations in the backward slice are 199755dc07dSRiver Riddle // loop iteration arguments, and neither is the value itself. 200755dc07dSRiver Riddle SmallPtrSet<Value, 8> iterCarriedValSet(iterCarriedArgs.begin(), 201755dc07dSRiver Riddle iterCarriedArgs.end()); 202755dc07dSRiver Riddle if (iterCarriedValSet.contains(value)) 203755dc07dSRiver Riddle return true; 204755dc07dSRiver Riddle 205755dc07dSRiver Riddle for (Operation *op : slice) 206755dc07dSRiver Riddle for (Value operand : op->getOperands()) 207755dc07dSRiver Riddle if (iterCarriedValSet.contains(operand)) 208755dc07dSRiver Riddle return true; 209755dc07dSRiver Riddle 210755dc07dSRiver Riddle return false; 211755dc07dSRiver Riddle } 212755dc07dSRiver Riddle 213755dc07dSRiver Riddle /// Utility to match a generic reduction given a list of iteration-carried 214755dc07dSRiver Riddle /// arguments, `iterCarriedArgs` and the position of the potential reduction 215755dc07dSRiver Riddle /// argument within the list, `redPos`. If a reduction is matched, returns the 216755dc07dSRiver Riddle /// reduced value and the topologically-sorted list of combiner operations 217755dc07dSRiver Riddle /// involved in the reduction. Otherwise, returns a null value. 218755dc07dSRiver Riddle /// 219755dc07dSRiver Riddle /// The matching algorithm relies on the following invariants, which are subject 220755dc07dSRiver Riddle /// to change: 221755dc07dSRiver Riddle /// 1. The first combiner operation must be a binary operation with the 222755dc07dSRiver Riddle /// iteration-carried value and the reduced value as operands. 223755dc07dSRiver Riddle /// 2. The iteration-carried value and combiner operations must be side 224755dc07dSRiver Riddle /// effect-free, have single result and a single use. 225755dc07dSRiver Riddle /// 3. Combiner operations must be immediately nested in the region op 226755dc07dSRiver Riddle /// performing the reduction. 227755dc07dSRiver Riddle /// 4. Reduction def-use chain must end in a terminator op that yields the 228755dc07dSRiver Riddle /// next iteration/output values in the same order as the iteration-carried 229755dc07dSRiver Riddle /// values in `iterCarriedArgs`. 230755dc07dSRiver Riddle /// 5. `iterCarriedArgs` must contain all the iteration-carried/output values 231755dc07dSRiver Riddle /// of the region op performing the reduction. 232755dc07dSRiver Riddle /// 233755dc07dSRiver Riddle /// This utility is generic enough to detect reductions involving multiple 234755dc07dSRiver Riddle /// combiner operations (disabled for now) across multiple dialects, including 235755dc07dSRiver Riddle /// Linalg, Affine and SCF. For the sake of genericity, it does not return 236755dc07dSRiver Riddle /// specific enum values for the combiner operations since its goal is also 237755dc07dSRiver Riddle /// matching reductions without pre-defined semantics in core MLIR. It's up to 238755dc07dSRiver Riddle /// each client to make sense out of the list of combiner operations. It's also 239755dc07dSRiver Riddle /// up to each client to check for additional invariants on the expected 240755dc07dSRiver Riddle /// reductions not covered by this generic matching. 241755dc07dSRiver Riddle Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs, 242755dc07dSRiver Riddle unsigned redPos, 243755dc07dSRiver Riddle SmallVectorImpl<Operation *> &combinerOps) { 244755dc07dSRiver Riddle assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds"); 245755dc07dSRiver Riddle 246755dc07dSRiver Riddle BlockArgument redCarriedVal = iterCarriedArgs[redPos]; 247755dc07dSRiver Riddle if (!redCarriedVal.hasOneUse()) 248755dc07dSRiver Riddle return nullptr; 249755dc07dSRiver Riddle 250755dc07dSRiver Riddle // For now, the first combiner op must be a binary op. 251755dc07dSRiver Riddle Operation *combinerOp = *redCarriedVal.getUsers().begin(); 252755dc07dSRiver Riddle if (combinerOp->getNumOperands() != 2) 253755dc07dSRiver Riddle return nullptr; 254755dc07dSRiver Riddle Value reducedVal = combinerOp->getOperand(0) == redCarriedVal 255755dc07dSRiver Riddle ? combinerOp->getOperand(1) 256755dc07dSRiver Riddle : combinerOp->getOperand(0); 257755dc07dSRiver Riddle 258755dc07dSRiver Riddle Operation *redRegionOp = 259755dc07dSRiver Riddle iterCarriedArgs.front().getOwner()->getParent()->getParentOp(); 260755dc07dSRiver Riddle if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp)) 261755dc07dSRiver Riddle return nullptr; 262755dc07dSRiver Riddle 263755dc07dSRiver Riddle // Traverse the def-use chain starting from the first combiner op until a 264755dc07dSRiver Riddle // terminator is found. Gather all the combiner ops along the way in 265755dc07dSRiver Riddle // topological order. 266755dc07dSRiver Riddle while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) { 267fc367dfaSMahesh Ravishankar if (!isMemoryEffectFree(combinerOp) || combinerOp->getNumResults() != 1 || 268fc367dfaSMahesh Ravishankar !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp) 269755dc07dSRiver Riddle return nullptr; 270755dc07dSRiver Riddle 271755dc07dSRiver Riddle combinerOps.push_back(combinerOp); 272755dc07dSRiver Riddle combinerOp = *combinerOp->getUsers().begin(); 273755dc07dSRiver Riddle } 274755dc07dSRiver Riddle 275755dc07dSRiver Riddle // Limit matching to single combiner op until we can properly test reductions 276755dc07dSRiver Riddle // involving multiple combiners. 277755dc07dSRiver Riddle if (combinerOps.size() != 1) 278755dc07dSRiver Riddle return nullptr; 279755dc07dSRiver Riddle 280755dc07dSRiver Riddle // Check that the yielded value is in the same position as in 281755dc07dSRiver Riddle // `iterCarriedArgs`. 282755dc07dSRiver Riddle Operation *terminatorOp = combinerOp; 283755dc07dSRiver Riddle if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0]) 284755dc07dSRiver Riddle return nullptr; 285755dc07dSRiver Riddle 286755dc07dSRiver Riddle return reducedVal; 287755dc07dSRiver Riddle } 288