1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements Analysis functions specific to slicing in Function. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/SliceAnalysis.h" 14 #include "mlir/Analysis/TopologicalSortUtils.h" 15 #include "mlir/IR/Block.h" 16 #include "mlir/IR/Operation.h" 17 #include "mlir/Interfaces/SideEffectInterfaces.h" 18 #include "mlir/Support/LLVM.h" 19 #include "llvm/ADT/STLExtras.h" 20 #include "llvm/ADT/SetVector.h" 21 #include "llvm/ADT/SmallPtrSet.h" 22 23 /// 24 /// Implements Analysis functions specific to slicing in Function. 25 /// 26 27 using namespace mlir; 28 29 static void 30 getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice, 31 const SliceOptions::TransitiveFilter &filter = nullptr) { 32 if (!op) 33 return; 34 35 // Evaluate whether we should keep this use. 36 // This is useful in particular to implement scoping; i.e. return the 37 // transitive forwardSlice in the current scope. 38 if (filter && !filter(op)) 39 return; 40 41 for (Region ®ion : op->getRegions()) 42 for (Block &block : region) 43 for (Operation &blockOp : block) 44 if (forwardSlice->count(&blockOp) == 0) 45 getForwardSliceImpl(&blockOp, forwardSlice, filter); 46 for (Value result : op->getResults()) { 47 for (Operation *userOp : result.getUsers()) 48 if (forwardSlice->count(userOp) == 0) 49 getForwardSliceImpl(userOp, forwardSlice, filter); 50 } 51 52 forwardSlice->insert(op); 53 } 54 55 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice, 56 const ForwardSliceOptions &options) { 57 getForwardSliceImpl(op, forwardSlice, options.filter); 58 if (!options.inclusive) { 59 // Don't insert the top level operation, we just queried on it and don't 60 // want it in the results. 61 forwardSlice->remove(op); 62 } 63 64 // Reverse to get back the actual topological order. 65 // std::reverse does not work out of the box on SetVector and I want an 66 // in-place swap based thing (the real std::reverse, not the LLVM adapter). 67 SmallVector<Operation *, 0> v(forwardSlice->takeVector()); 68 forwardSlice->insert(v.rbegin(), v.rend()); 69 } 70 71 void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice, 72 const SliceOptions &options) { 73 for (Operation *user : root.getUsers()) 74 getForwardSliceImpl(user, forwardSlice, options.filter); 75 76 // Reverse to get back the actual topological order. 77 // std::reverse does not work out of the box on SetVector and I want an 78 // in-place swap based thing (the real std::reverse, not the LLVM adapter). 79 SmallVector<Operation *, 0> v(forwardSlice->takeVector()); 80 forwardSlice->insert(v.rbegin(), v.rend()); 81 } 82 83 static void getBackwardSliceImpl(Operation *op, 84 SetVector<Operation *> *backwardSlice, 85 const BackwardSliceOptions &options) { 86 if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>()) 87 return; 88 89 // Evaluate whether we should keep this def. 90 // This is useful in particular to implement scoping; i.e. return the 91 // transitive backwardSlice in the current scope. 92 if (options.filter && !options.filter(op)) 93 return; 94 95 auto processValue = [&](Value value) { 96 if (auto *definingOp = value.getDefiningOp()) { 97 if (backwardSlice->count(definingOp) == 0) 98 getBackwardSliceImpl(definingOp, backwardSlice, options); 99 } else if (auto blockArg = dyn_cast<BlockArgument>(value)) { 100 if (options.omitBlockArguments) 101 return; 102 103 Block *block = blockArg.getOwner(); 104 Operation *parentOp = block->getParentOp(); 105 // TODO: determine whether we want to recurse backward into the other 106 // blocks of parentOp, which are not technically backward unless they flow 107 // into us. For now, just bail. 108 if (parentOp && backwardSlice->count(parentOp) == 0) { 109 assert(parentOp->getNumRegions() == 1 && 110 parentOp->getRegion(0).getBlocks().size() == 1); 111 getBackwardSliceImpl(parentOp, backwardSlice, options); 112 } 113 } else { 114 llvm_unreachable("No definingOp and not a block argument."); 115 } 116 }; 117 118 if (!options.omitUsesFromAbove) { 119 llvm::for_each(op->getRegions(), [&](Region ®ion) { 120 // Walk this region recursively to collect the regions that descend from 121 // this op's nested regions (inclusive). 122 SmallPtrSet<Region *, 4> descendents; 123 region.walk( 124 [&](Region *childRegion) { descendents.insert(childRegion); }); 125 region.walk([&](Operation *op) { 126 for (OpOperand &operand : op->getOpOperands()) { 127 if (!descendents.contains(operand.get().getParentRegion())) 128 processValue(operand.get()); 129 } 130 }); 131 }); 132 } 133 llvm::for_each(op->getOperands(), processValue); 134 135 backwardSlice->insert(op); 136 } 137 138 void mlir::getBackwardSlice(Operation *op, 139 SetVector<Operation *> *backwardSlice, 140 const BackwardSliceOptions &options) { 141 getBackwardSliceImpl(op, backwardSlice, options); 142 143 if (!options.inclusive) { 144 // Don't insert the top level operation, we just queried on it and don't 145 // want it in the results. 146 backwardSlice->remove(op); 147 } 148 } 149 150 void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice, 151 const BackwardSliceOptions &options) { 152 if (Operation *definingOp = root.getDefiningOp()) { 153 getBackwardSlice(definingOp, backwardSlice, options); 154 return; 155 } 156 Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp(); 157 getBackwardSlice(bbAargOwner, backwardSlice, options); 158 } 159 160 SetVector<Operation *> 161 mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions, 162 const ForwardSliceOptions &forwardSliceOptions) { 163 SetVector<Operation *> slice; 164 slice.insert(op); 165 166 unsigned currentIndex = 0; 167 SetVector<Operation *> backwardSlice; 168 SetVector<Operation *> forwardSlice; 169 while (currentIndex != slice.size()) { 170 auto *currentOp = (slice)[currentIndex]; 171 // Compute and insert the backwardSlice starting from currentOp. 172 backwardSlice.clear(); 173 getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions); 174 slice.insert(backwardSlice.begin(), backwardSlice.end()); 175 176 // Compute and insert the forwardSlice starting from currentOp. 177 forwardSlice.clear(); 178 getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions); 179 slice.insert(forwardSlice.begin(), forwardSlice.end()); 180 ++currentIndex; 181 } 182 return topologicalSort(slice); 183 } 184 185 /// Returns true if `value` (transitively) depends on iteration-carried values 186 /// of the given `ancestorOp`. 187 static bool dependsOnCarriedVals(Value value, 188 ArrayRef<BlockArgument> iterCarriedArgs, 189 Operation *ancestorOp) { 190 // Compute the backward slice of the value. 191 SetVector<Operation *> slice; 192 BackwardSliceOptions sliceOptions; 193 sliceOptions.filter = [&](Operation *op) { 194 return !ancestorOp->isAncestor(op); 195 }; 196 getBackwardSlice(value, &slice, sliceOptions); 197 198 // Check that none of the operands of the operations in the backward slice are 199 // loop iteration arguments, and neither is the value itself. 200 SmallPtrSet<Value, 8> iterCarriedValSet(iterCarriedArgs.begin(), 201 iterCarriedArgs.end()); 202 if (iterCarriedValSet.contains(value)) 203 return true; 204 205 for (Operation *op : slice) 206 for (Value operand : op->getOperands()) 207 if (iterCarriedValSet.contains(operand)) 208 return true; 209 210 return false; 211 } 212 213 /// Utility to match a generic reduction given a list of iteration-carried 214 /// arguments, `iterCarriedArgs` and the position of the potential reduction 215 /// argument within the list, `redPos`. If a reduction is matched, returns the 216 /// reduced value and the topologically-sorted list of combiner operations 217 /// involved in the reduction. Otherwise, returns a null value. 218 /// 219 /// The matching algorithm relies on the following invariants, which are subject 220 /// to change: 221 /// 1. The first combiner operation must be a binary operation with the 222 /// iteration-carried value and the reduced value as operands. 223 /// 2. The iteration-carried value and combiner operations must be side 224 /// effect-free, have single result and a single use. 225 /// 3. Combiner operations must be immediately nested in the region op 226 /// performing the reduction. 227 /// 4. Reduction def-use chain must end in a terminator op that yields the 228 /// next iteration/output values in the same order as the iteration-carried 229 /// values in `iterCarriedArgs`. 230 /// 5. `iterCarriedArgs` must contain all the iteration-carried/output values 231 /// of the region op performing the reduction. 232 /// 233 /// This utility is generic enough to detect reductions involving multiple 234 /// combiner operations (disabled for now) across multiple dialects, including 235 /// Linalg, Affine and SCF. For the sake of genericity, it does not return 236 /// specific enum values for the combiner operations since its goal is also 237 /// matching reductions without pre-defined semantics in core MLIR. It's up to 238 /// each client to make sense out of the list of combiner operations. It's also 239 /// up to each client to check for additional invariants on the expected 240 /// reductions not covered by this generic matching. 241 Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs, 242 unsigned redPos, 243 SmallVectorImpl<Operation *> &combinerOps) { 244 assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds"); 245 246 BlockArgument redCarriedVal = iterCarriedArgs[redPos]; 247 if (!redCarriedVal.hasOneUse()) 248 return nullptr; 249 250 // For now, the first combiner op must be a binary op. 251 Operation *combinerOp = *redCarriedVal.getUsers().begin(); 252 if (combinerOp->getNumOperands() != 2) 253 return nullptr; 254 Value reducedVal = combinerOp->getOperand(0) == redCarriedVal 255 ? combinerOp->getOperand(1) 256 : combinerOp->getOperand(0); 257 258 Operation *redRegionOp = 259 iterCarriedArgs.front().getOwner()->getParent()->getParentOp(); 260 if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp)) 261 return nullptr; 262 263 // Traverse the def-use chain starting from the first combiner op until a 264 // terminator is found. Gather all the combiner ops along the way in 265 // topological order. 266 while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) { 267 if (!isMemoryEffectFree(combinerOp) || combinerOp->getNumResults() != 1 || 268 !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp) 269 return nullptr; 270 271 combinerOps.push_back(combinerOp); 272 combinerOp = *combinerOp->getUsers().begin(); 273 } 274 275 // Limit matching to single combiner op until we can properly test reductions 276 // involving multiple combiners. 277 if (combinerOps.size() != 1) 278 return nullptr; 279 280 // Check that the yielded value is in the same position as in 281 // `iterCarriedArgs`. 282 Operation *terminatorOp = combinerOp; 283 if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0]) 284 return nullptr; 285 286 return reducedVal; 287 } 288