//====----- OutlineShapeComputation.cpp -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass outlines the shape computation feeding each shape.with_shape op
// into a standalone, private shape.func, and records a mapping from each
// dynamically-shaped value to (shape function symbol, inputs) in the
// ShapeMappingAnalysis so later passes can query how a shape is computed.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Debug.h"
#include <queue>
#include <unordered_set>
#include <vector>

namespace mlir {
#define GEN_PASS_DEF_OUTLINESHAPECOMPUTATION
#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "outline-shape-computation"

using namespace mlir;

namespace {

// A Value is an input of the cluster if it is an operand of an operation in the
// cluster and its defining operation is not in the cluster.
//
// Returns the inputs in deterministic first-use order, deduplicated. Note that
// block arguments have a null defining op, so `opSet.contains(nullptr)` is
// false and they are correctly collected as inputs.
SmallVector<Value, 4>
getInputsOfCluster(const llvm::SmallVector<Operation *, 8> &cluster) {
  SmallVector<Value, 4> inputs;
  llvm::SmallDenseSet<Value> inputSet;
  llvm::SmallDenseSet<Operation *> opSet;
  // Index the cluster's operations for O(1) membership tests below.
  for (Operation *op : cluster) {
    bool inserted = opSet.insert(op).second;
    (void)inserted;
    assert(inserted && "cluster contains duplicate operations");
  }

  for (Operation *op : cluster) {
    for (Value operand : op->getOperands()) {
      Operation *operandOp = operand.getDefiningOp();
      if (opSet.contains(operandOp)) {
        // Skip if defining op is in the cluster.
        continue;
      }
      // `inputSet` dedupes; `inputs` preserves encounter order.
      if (inputSet.insert(operand).second)
        inputs.push_back(operand);
    }
  }
  return inputs;
}

// Create a shape.func representing the shape computation for `shape`.
//
// The cluster's operations are cloned into the new function body and a
// shape.return of (the clone of) `shape` is appended. An empty cluster means
// `shape` is not computed by any outlineable ops (e.g. it is a function
// argument), so an identity function taking `shape` itself is emitted.
// Returns the created (private) shape.func together with the SSA values that
// must be passed as its call operands.
std::pair<shape::FuncOp, SmallVector<Value>>
createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
                      Value shape, StringRef fnName, Location loc) {
  SmallVector<Value, 4> inputs = getInputsOfCluster(cluster);
  auto fnType =
      cluster.empty()
          ? b.getFunctionType(shape.getType(), shape.getType())
          : b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType());
  shape::FuncOp fnOp = b.create<shape::FuncOp>(loc, fnName, fnType);
  Block *block = fnOp.addEntryBlock();
  b.setInsertionPointToEnd(block);
  IRMapping bvm;
  if (cluster.empty()) {
    // Identity function: the single argument IS the shape.
    bvm.map(shape, fnOp.getArgument(0));
  } else {
    // Map each outside input to the corresponding entry-block argument so the
    // clones below reference the arguments instead of the original values.
    for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
      bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
  }

  // `cluster` is already topologically ordered (see getOrderedClusters), so a
  // straight clone preserves def-before-use.
  for (Operation *op : cluster)
    b.clone(*op, bvm);
  llvm::SmallVector<Value, 4> fnReturns;
  // lookupOrDefault: in the empty-cluster case `shape` maps to the argument;
  // otherwise it maps to its clone inside the function body.
  fnReturns.push_back(bvm.lookupOrDefault(shape));

  b.create<shape::ReturnOp>(loc, fnReturns);
  fnOp.setPrivate();
  return std::make_pair(fnOp, inputs);
}

// The operations in the cluster might be unsorted, which could be inconvenient
// when creating shape.func op.
//
// Re-orders every cluster into the operations' program order by walking
// `funcOp` once; since the IR is in SSA form, program order within the walk is
// a valid topological (def-before-use) order for cloning.
DenseMap<Value, SmallVector<Operation *, 8>>
getOrderedClusters(const DenseMap<Value, DenseSet<Operation *>> &clusters,
                   func::FuncOp funcOp) {
  // Compute all clusters that each operation is in
  DenseMap<Operation *, SmallVector<Value>> op2Shapes;
  for (const auto &it : clusters) {
    Value shape = it.first;
    const DenseSet<Operation *> &cluster = it.second;
    for (Operation *cOp : cluster)
      op2Shapes[cOp].push_back(shape);
  }

  // Iterate through all operations in order. Get all the clusters `cOp` belongs
  // to and construct the new ordered cluster as it traverses.
  DenseMap<Value, SmallVector<Operation *, 8>> orderedClusters;
  funcOp.walk([&](Operation *op) {
    auto it = op2Shapes.find(op);
    if (it != op2Shapes.end()) {
      Operation *cOp = it->first;
      // An op may belong to several shapes' clusters; append it to each.
      for (Value shape : it->second)
        orderedClusters[shape].push_back(cOp);
    }
  });

  return orderedClusters;
}

// For every shape.with_shape on a ranked tensor value, create (or reuse) the
// shape.func that computes its shape operand, and record
// value -> (func symbol, inputs) in `shapeMappingAnalysis`.
// `dynShape2ShapeFunc` memoizes shape -> created function across calls so the
// same shape computation is outlined only once.
void constructShapeFunc(
    const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
    DenseMap<Value, SmallVector<Operation *, 8>> &clusters,
    SymbolTable &symbolTable,
    DenseMap<Value, shape::ShapeMappingValue> &dynShape2ShapeFunc,
    func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) {
  std::string shapeCalculationNamePrefix = "shape_cal_";
  int shapeCalculationNameIdx = 0;
  OpBuilder builder(context);

  // Construct a shape function
  for (shape::WithOp withOp : allWithOps) {
    Value value = withOp.getOperand();
    Value shape = withOp.getShape();
    // Only ranked tensor values participate in the mapping.
    RankedTensorType rankedType = dyn_cast<RankedTensorType>(value.getType());
    if (rankedType == nullptr)
      continue;

    const SmallVector<Operation *, 8> &cluster = clusters[shape];
    shape::ShapeMappingValue shapeMappingValue;
    auto it = dynShape2ShapeFunc.find(shape);
    if (it == dynShape2ShapeFunc.end()) {
      // First time this shape is seen: outline a new shape.func next to
      // `funcOp` and let the symbol table uniquify its name.
      std::string name = shapeCalculationNamePrefix +
                         std::to_string(shapeCalculationNameIdx++);
      Location loc = value.getLoc();
      builder.setInsertionPointAfter(funcOp);
      auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
      const SmallVector<Value> &inputs = pair.second;
      shape::FuncOp shapeFuncOp = pair.first;
      StringAttr insertedName = symbolTable.insert(shapeFuncOp);
      auto symbol = FlatSymbolRefAttr::get(context, insertedName);

      shapeMappingValue.funcSymbol = symbol;
      shapeMappingValue.inputs = inputs;
    } else {
      // Reuse the function created for an earlier with_shape on this shape.
      shapeMappingValue = it->second;
    }
    dynShape2ShapeFunc[shape] = shapeMappingValue;
    shapeMappingAnalysis.shapeMapping.insert(
        std::make_pair(value, shapeMappingValue));
  }
}

struct OutlineShapeComputationPass
    : public impl::OutlineShapeComputationBase<OutlineShapeComputationPass> {

  void runOnOperation() override;

private:
  // Returns whether all of `op`'s results are (transitively) consumed only as
  // the shape operand of shape.with_shape ops; memoized in
  // `onlyUsedByWithShapes`.
  bool calOnlyUsedByWithShapesRecursively(Operation *op, Value prevOutput);

  // Collects into `clusters[shape]` the backward slice of ops (restricted to
  // `onlyUsedByWithShapes`) that compute `shape`.
  void getClusterFromValue(Value shape,
                           DenseMap<Value, DenseSet<Operation *>> &clusters);

  DenseMap<Value, SmallVector<Operation *, 8>>
  constructClustersForEachShape(const std::vector<shape::WithOp> &allWithOps,
                                func::FuncOp funcOp);

  // Ops whose results feed only shape.with_shape shape operands; reset per
  // func (see runOnOperation).
  DenseSet<Operation *> onlyUsedByWithShapes;
};

// Rewrites tensor.dim into shape.shape_of + shape.get_extent so dimension
// queries become part of the shape-dialect computation that can be outlined.
class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOf =
        rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getSource());
    rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
                                                    op.getIndex());
    return success();
  }
};

void OutlineShapeComputationPass::runOnOperation() {
  ModuleOp moduleOp = getOperation();
  SymbolTable symbolTable(moduleOp);
  DenseMap<Value, shape::ShapeMappingValue> dynShape2ShapeFunc;
  auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
  // TODO: This is as we populate this analysis during a pass that mutates. This
  // pass currently requires 1 single module being compiled.
  shapeMappingAnalysis.shapeMapping.clear();
  // Keep the just-populated analysis alive for downstream passes.
  markAnalysesPreserved<shape::ShapeMappingAnalysis>();

  moduleOp.walk([&](func::FuncOp funcOp) {
    MLIRContext *context = funcOp.getContext();
    // Normalize tensor.dim into shape-dialect ops before clustering.
    RewritePatternSet prevPatterns(context);
    prevPatterns.insert<TensorDimOpRewriter>(context);
    if (failed(applyPatternsGreedily(funcOp, std::move(prevPatterns))))
      return signalPassFailure();

    // initialize class member `onlyUsedByWithShapes`
    onlyUsedByWithShapes.clear();
    funcOp.walk([&](Operation *op) {
      calOnlyUsedByWithShapesRecursively(op, /*prevOutput=*/nullptr);
    });
    LLVM_DEBUG({
      llvm::dbgs() << "onlyUsedByWithShapes table: \n";
      for (auto it : onlyUsedByWithShapes)
        llvm::dbgs() << *it << "\n";
    });

    // collect all the shape.with_shape ops.
    std::vector<shape::WithOp> allWithOps;
    funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });

    DenseMap<Value, SmallVector<Operation *, 8>> clusters =
        constructClustersForEachShape(allWithOps, funcOp);
    constructShapeFunc(allWithOps, context, clusters, symbolTable,
                       dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);

    for (shape::WithOp withOp : allWithOps) {
      Value value = withOp.getOperand();
      // early_inc_range: we mutate/redirect users while iterating over them.
      for (Operation *user :
           llvm::make_early_inc_range(withOp.getResult().getUsers())) {
        if (auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user)) {
          // For pattern like
          //  %1 = shape.with_shape %arg1, %0
          //  %2 = shape.value_of %1
          // because shape.value doesn't care the shape, the shape.with_shape is
          // redundant.
          // If type of %arg1 and %2 has same type, just
          // replaced %2 with %arg1.
          // If type of %arg1 has different type like !shape.value_shape,
          // transform into
          //  %2 = shape.value_of %arg1
          if (valueOf.getType() == value.getType())
            valueOf.replaceAllUsesWith(value);
          else
            valueOf.setOperand(value);
        }
      }
    }

    // Apply patterns, note this also performs DCE.
    if (failed(applyPatternsGreedily(funcOp, {})))
      return signalPassFailure();
  });
}

DenseMap<Value, SmallVector<Operation *, 8>>
OutlineShapeComputationPass::constructClustersForEachShape(
    const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
  DenseMap<Value, DenseSet<Operation *>> clusters;
  for (shape::WithOp withOp : allWithOps) {
    Value shape = withOp.getShape();
    // Several with_shape ops may share one shape value; cluster it once.
    if (clusters.count(shape) == 0)
      getClusterFromValue(shape, clusters);
  }
  return getOrderedClusters(clusters, funcOp);
}

// The output of a cluster is the `shape`, and the inputs are the outputs of
// operations who are not in `onlyUsedByWithShapes`
//
// Implemented as a backward BFS from `shape`'s defining op across operand
// edges; traversal stops at ops not in `onlyUsedByWithShapes` (their results
// become the cluster's inputs) and at block arguments (null defining op).
void OutlineShapeComputationPass::getClusterFromValue(
    Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) {
  DenseSet<Operation *> cluster;

  DenseSet<Operation *> visited;
  std::queue<Operation *> queue;

  // defOp == nullptr means shape is the argument of the func op
  if (Operation *defOp = shape.getDefiningOp()) {
    visited.insert(defOp);
    queue.push(defOp);
  }
  while (!queue.empty()) {
    Operation *op = queue.front();
    queue.pop();
    if (onlyUsedByWithShapes.contains(op)) {
      cluster.insert(op);
      for (Value inp : op->getOperands()) {
        Operation *inpDefOp = inp.getDefiningOp();
        // `visited` guards against re-enqueueing shared defs in the DAG.
        if (nullptr != inpDefOp && visited.insert(inpDefOp).second)
          queue.push(inpDefOp);
      }
    }
  }

  clusters[shape] = std::move(cluster);
}

// Returns whether `op` is a shape.with_shape, or all the users' of `op`
// eventually point to the shape operand of shape.with_shape ops
//
// `prevOutput` is the result of `op`'s producer through which we reached it
// (nullptr at the walk roots); for a with_shape it distinguishes the shape
// operand (qualifies) from the value operand (does not). Positive results are
// cached in `onlyUsedByWithShapes`, which also serves as the memo table.
bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
    Operation *op, Value prevOutput) {
  if (onlyUsedByWithShapes.contains(op))
    return true;

  if (auto withOp = llvm::dyn_cast<shape::WithOp>(op))
    return withOp.getShape() == prevOutput;

  // A result with no users never reaches a with_shape.
  if (op->use_empty())
    return false;

  for (Value oup : op->getResults())
    for (Operation *user : oup.getUsers())
      if (!calOnlyUsedByWithShapesRecursively(user, oup))
        return false;

  onlyUsedByWithShapes.insert(op);
  return true;
}

} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createOutlineShapeComputationPass() {
  return std::make_unique<OutlineShapeComputationPass>();
}