xref: /llvm-project/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
19f77909aSYuanqiang Liu //====----- OutlineShapeComputation.cpp -----------------------------------===//
29f77909aSYuanqiang Liu //
39f77909aSYuanqiang Liu // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49f77909aSYuanqiang Liu // See https://llvm.org/LICENSE.txt for license information.
59f77909aSYuanqiang Liu // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69f77909aSYuanqiang Liu //
79f77909aSYuanqiang Liu //===----------------------------------------------------------------------===//
89f77909aSYuanqiang Liu 
99f77909aSYuanqiang Liu #include "mlir/Dialect/Func/IR/FuncOps.h"
109f77909aSYuanqiang Liu #include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
119f77909aSYuanqiang Liu #include "mlir/Dialect/Shape/IR/Shape.h"
129f77909aSYuanqiang Liu #include "mlir/Dialect/Shape/Transforms/Passes.h"
139f77909aSYuanqiang Liu #include "mlir/Dialect/Tensor/IR/Tensor.h"
144d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
159f77909aSYuanqiang Liu #include "mlir/IR/Matchers.h"
169f77909aSYuanqiang Liu #include "mlir/Pass/Pass.h"
179f77909aSYuanqiang Liu #include "mlir/Transforms/DialectConversion.h"
189f77909aSYuanqiang Liu #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
199f77909aSYuanqiang Liu #include "llvm/ADT/DenseSet.h"
209f77909aSYuanqiang Liu #include "llvm/Support/Debug.h"
219f77909aSYuanqiang Liu #include <queue>
229f77909aSYuanqiang Liu #include <unordered_set>
239f77909aSYuanqiang Liu #include <vector>
249f77909aSYuanqiang Liu 
259f77909aSYuanqiang Liu namespace mlir {
269f77909aSYuanqiang Liu #define GEN_PASS_DEF_OUTLINESHAPECOMPUTATION
279f77909aSYuanqiang Liu #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
289f77909aSYuanqiang Liu } // namespace mlir
299f77909aSYuanqiang Liu 
309f77909aSYuanqiang Liu #define DEBUG_TYPE "outline-shape-computation"
319f77909aSYuanqiang Liu 
329f77909aSYuanqiang Liu using namespace mlir;
339f77909aSYuanqiang Liu 
349f77909aSYuanqiang Liu namespace {
359f77909aSYuanqiang Liu 
369f77909aSYuanqiang Liu // A Value is an input of the cluster if it is an operand of an operation in the
379f77909aSYuanqiang Liu // cluster and its defining operation is not in the cluster.
389f77909aSYuanqiang Liu SmallVector<Value, 4>
399f77909aSYuanqiang Liu getInputsOfCluster(const llvm::SmallVector<Operation *, 8> &cluster) {
409f77909aSYuanqiang Liu   SmallVector<Value, 4> inputs;
419f77909aSYuanqiang Liu   llvm::SmallDenseSet<Value> inputSet;
429f77909aSYuanqiang Liu   llvm::SmallDenseSet<Operation *> opSet;
439f77909aSYuanqiang Liu   for (Operation *op : cluster) {
449f77909aSYuanqiang Liu     bool inserted = opSet.insert(op).second;
459f77909aSYuanqiang Liu     (void)inserted;
469f77909aSYuanqiang Liu     assert(inserted && "cluster contains duplicate operations");
479f77909aSYuanqiang Liu   }
489f77909aSYuanqiang Liu 
499f77909aSYuanqiang Liu   for (Operation *op : cluster) {
509f77909aSYuanqiang Liu     for (Value operand : op->getOperands()) {
519f77909aSYuanqiang Liu       Operation *operandOp = operand.getDefiningOp();
5269ffd49cSKazu Hirata       if (opSet.contains(operandOp)) {
539f77909aSYuanqiang Liu         // Skip if defining op is in the cluster.
549f77909aSYuanqiang Liu         continue;
559f77909aSYuanqiang Liu       }
569f77909aSYuanqiang Liu       if (inputSet.insert(operand).second)
579f77909aSYuanqiang Liu         inputs.push_back(operand);
589f77909aSYuanqiang Liu     }
599f77909aSYuanqiang Liu   }
609f77909aSYuanqiang Liu   return inputs;
619f77909aSYuanqiang Liu }
629f77909aSYuanqiang Liu 
639f77909aSYuanqiang Liu // Create a shape.func representing the shape computation for `shape`.
649f77909aSYuanqiang Liu std::pair<shape::FuncOp, SmallVector<Value>>
659f77909aSYuanqiang Liu createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
669f77909aSYuanqiang Liu                       Value shape, StringRef fnName, Location loc) {
679f77909aSYuanqiang Liu   SmallVector<Value, 4> inputs = getInputsOfCluster(cluster);
689f77909aSYuanqiang Liu   auto fnType =
699f77909aSYuanqiang Liu       cluster.empty()
709f77909aSYuanqiang Liu           ? b.getFunctionType(shape.getType(), shape.getType())
719f77909aSYuanqiang Liu           : b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType());
729f77909aSYuanqiang Liu   shape::FuncOp fnOp = b.create<shape::FuncOp>(loc, fnName, fnType);
739f77909aSYuanqiang Liu   Block *block = fnOp.addEntryBlock();
74b613a540SMatthias Springer   b.setInsertionPointToEnd(block);
754d67b278SJeff Niu   IRMapping bvm;
769f77909aSYuanqiang Liu   if (cluster.empty()) {
779f77909aSYuanqiang Liu     bvm.map(shape, fnOp.getArgument(0));
789f77909aSYuanqiang Liu   } else {
799f77909aSYuanqiang Liu     for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
809f77909aSYuanqiang Liu       bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
819f77909aSYuanqiang Liu   }
829f77909aSYuanqiang Liu 
839f77909aSYuanqiang Liu   for (Operation *op : cluster)
849f77909aSYuanqiang Liu     b.clone(*op, bvm);
859f77909aSYuanqiang Liu   llvm::SmallVector<Value, 4> fnReturns;
869f77909aSYuanqiang Liu   fnReturns.push_back(bvm.lookupOrDefault(shape));
879f77909aSYuanqiang Liu 
889f77909aSYuanqiang Liu   b.create<shape::ReturnOp>(loc, fnReturns);
899f77909aSYuanqiang Liu   fnOp.setPrivate();
909f77909aSYuanqiang Liu   return std::make_pair(fnOp, inputs);
919f77909aSYuanqiang Liu }
929f77909aSYuanqiang Liu 
939f77909aSYuanqiang Liu // The operations in the cluster might be unsorted, which could be inconvenient
949f77909aSYuanqiang Liu // when creating shape.func op.
959f77909aSYuanqiang Liu DenseMap<Value, SmallVector<Operation *, 8>>
969f77909aSYuanqiang Liu getOrderedClusters(const DenseMap<Value, DenseSet<Operation *>> &clusters,
979f77909aSYuanqiang Liu                    func::FuncOp funcOp) {
989f77909aSYuanqiang Liu   // Compute all clusters that each operation is in
999f77909aSYuanqiang Liu   DenseMap<Operation *, SmallVector<Value>> op2Shapes;
1009f77909aSYuanqiang Liu   for (const auto &it : clusters) {
1019f77909aSYuanqiang Liu     Value shape = it.first;
1029f77909aSYuanqiang Liu     const DenseSet<Operation *> &cluster = it.second;
1039f77909aSYuanqiang Liu     for (Operation *cOp : cluster)
1049f77909aSYuanqiang Liu       op2Shapes[cOp].push_back(shape);
1059f77909aSYuanqiang Liu   }
1069f77909aSYuanqiang Liu 
1079f77909aSYuanqiang Liu   // Iterate through all operations in order. Get all the clusters `cOp` belongs
1089f77909aSYuanqiang Liu   // to and construct the new ordered cluster as it traverses.
1099f77909aSYuanqiang Liu   DenseMap<Value, SmallVector<Operation *, 8>> orderedClusters;
1109f77909aSYuanqiang Liu   funcOp.walk([&](Operation *op) {
1119f77909aSYuanqiang Liu     auto it = op2Shapes.find(op);
1129f77909aSYuanqiang Liu     if (it != op2Shapes.end()) {
1139f77909aSYuanqiang Liu       Operation *cOp = it->first;
1149f77909aSYuanqiang Liu       for (Value shape : it->second)
1159f77909aSYuanqiang Liu         orderedClusters[shape].push_back(cOp);
1169f77909aSYuanqiang Liu     }
1179f77909aSYuanqiang Liu   });
1189f77909aSYuanqiang Liu 
1199f77909aSYuanqiang Liu   return orderedClusters;
1209f77909aSYuanqiang Liu }
1219f77909aSYuanqiang Liu 
1229f77909aSYuanqiang Liu void constructShapeFunc(
1239f77909aSYuanqiang Liu     const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
1249f77909aSYuanqiang Liu     DenseMap<Value, SmallVector<Operation *, 8>> &clusters,
1259f77909aSYuanqiang Liu     SymbolTable &symbolTable,
1269f77909aSYuanqiang Liu     DenseMap<Value, shape::ShapeMappingValue> &dynShape2ShapeFunc,
1279f77909aSYuanqiang Liu     func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) {
1289f77909aSYuanqiang Liu   std::string shapeCalculationNamePrefix = "shape_cal_";
1299f77909aSYuanqiang Liu   int shapeCalculationNameIdx = 0;
1309f77909aSYuanqiang Liu   OpBuilder builder(context);
1319f77909aSYuanqiang Liu 
1329f77909aSYuanqiang Liu   // Construct a shape function
1339f77909aSYuanqiang Liu   for (shape::WithOp withOp : allWithOps) {
1349f77909aSYuanqiang Liu     Value value = withOp.getOperand();
1359f77909aSYuanqiang Liu     Value shape = withOp.getShape();
1365550c821STres Popp     RankedTensorType rankedType = dyn_cast<RankedTensorType>(value.getType());
1379f77909aSYuanqiang Liu     if (rankedType == nullptr)
1389f77909aSYuanqiang Liu       continue;
1399f77909aSYuanqiang Liu 
1409f77909aSYuanqiang Liu     const SmallVector<Operation *, 8> &cluster = clusters[shape];
1419f77909aSYuanqiang Liu     shape::ShapeMappingValue shapeMappingValue;
1429f77909aSYuanqiang Liu     auto it = dynShape2ShapeFunc.find(shape);
1439f77909aSYuanqiang Liu     if (it == dynShape2ShapeFunc.end()) {
1449f77909aSYuanqiang Liu       std::string name = shapeCalculationNamePrefix +
1459f77909aSYuanqiang Liu                          std::to_string(shapeCalculationNameIdx++);
1469f77909aSYuanqiang Liu       Location loc = value.getLoc();
1479f77909aSYuanqiang Liu       builder.setInsertionPointAfter(funcOp);
1489f77909aSYuanqiang Liu       auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
1499f77909aSYuanqiang Liu       const SmallVector<Value> &inputs = pair.second;
1509f77909aSYuanqiang Liu       shape::FuncOp shapeFuncOp = pair.first;
1519f77909aSYuanqiang Liu       StringAttr insertedName = symbolTable.insert(shapeFuncOp);
1529f77909aSYuanqiang Liu       auto symbol = FlatSymbolRefAttr::get(context, insertedName);
1539f77909aSYuanqiang Liu 
1549f77909aSYuanqiang Liu       shapeMappingValue.funcSymbol = symbol;
1559f77909aSYuanqiang Liu       shapeMappingValue.inputs = inputs;
1569f77909aSYuanqiang Liu     } else {
1579f77909aSYuanqiang Liu       shapeMappingValue = it->second;
1589f77909aSYuanqiang Liu     }
1599f77909aSYuanqiang Liu     dynShape2ShapeFunc[shape] = shapeMappingValue;
1609f77909aSYuanqiang Liu     shapeMappingAnalysis.shapeMapping.insert(
1619f77909aSYuanqiang Liu         std::make_pair(value, shapeMappingValue));
1629f77909aSYuanqiang Liu   }
1639f77909aSYuanqiang Liu }
1649f77909aSYuanqiang Liu 
1659f77909aSYuanqiang Liu struct OutlineShapeComputationPass
1669f77909aSYuanqiang Liu     : public impl::OutlineShapeComputationBase<OutlineShapeComputationPass> {
1679f77909aSYuanqiang Liu 
1689f77909aSYuanqiang Liu   void runOnOperation() override;
1699f77909aSYuanqiang Liu 
1709f77909aSYuanqiang Liu private:
1719f77909aSYuanqiang Liu   bool calOnlyUsedByWithShapesRecursively(Operation *op, Value prevOutput);
1729f77909aSYuanqiang Liu 
1739f77909aSYuanqiang Liu   void getClusterFromValue(Value shape,
1749f77909aSYuanqiang Liu                            DenseMap<Value, DenseSet<Operation *>> &clusters);
1759f77909aSYuanqiang Liu 
1769f77909aSYuanqiang Liu   DenseMap<Value, SmallVector<Operation *, 8>>
1779f77909aSYuanqiang Liu   constructClustersForEachShape(const std::vector<shape::WithOp> &allWithOps,
1789f77909aSYuanqiang Liu                                 func::FuncOp funcOp);
1799f77909aSYuanqiang Liu 
1809f77909aSYuanqiang Liu   DenseSet<Operation *> onlyUsedByWithShapes;
1819f77909aSYuanqiang Liu };
1829f77909aSYuanqiang Liu 
1839f77909aSYuanqiang Liu class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
1849f77909aSYuanqiang Liu   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1859f77909aSYuanqiang Liu 
1869f77909aSYuanqiang Liu   LogicalResult matchAndRewrite(tensor::DimOp op,
1879f77909aSYuanqiang Liu                                 PatternRewriter &rewriter) const override {
1889f77909aSYuanqiang Liu     auto shapeOf =
1899f77909aSYuanqiang Liu         rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getSource());
1909f77909aSYuanqiang Liu     rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
1919f77909aSYuanqiang Liu                                                     op.getIndex());
1929f77909aSYuanqiang Liu     return success();
1939f77909aSYuanqiang Liu   }
1949f77909aSYuanqiang Liu };
1959f77909aSYuanqiang Liu 
1969f77909aSYuanqiang Liu void OutlineShapeComputationPass::runOnOperation() {
1979f77909aSYuanqiang Liu   ModuleOp moduleOp = getOperation();
1989f77909aSYuanqiang Liu   SymbolTable symbolTable(moduleOp);
1999f77909aSYuanqiang Liu   DenseMap<Value, shape::ShapeMappingValue> dynShape2ShapeFunc;
2009f77909aSYuanqiang Liu   auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
2019f77909aSYuanqiang Liu   // TODO: This is as we populate this analysis during a pass that mutates. This
2029f77909aSYuanqiang Liu   // pass currently requires 1 single module being compiled.
2039f77909aSYuanqiang Liu   shapeMappingAnalysis.shapeMapping.clear();
2049f77909aSYuanqiang Liu   markAnalysesPreserved<shape::ShapeMappingAnalysis>();
2059f77909aSYuanqiang Liu 
2069f77909aSYuanqiang Liu   moduleOp.walk([&](func::FuncOp funcOp) {
2079f77909aSYuanqiang Liu     MLIRContext *context = funcOp.getContext();
2089f77909aSYuanqiang Liu     RewritePatternSet prevPatterns(context);
2099f77909aSYuanqiang Liu     prevPatterns.insert<TensorDimOpRewriter>(context);
210*09dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(funcOp, std::move(prevPatterns))))
2119f77909aSYuanqiang Liu       return signalPassFailure();
2129f77909aSYuanqiang Liu 
2139f77909aSYuanqiang Liu     // initialize class member `onlyUsedByWithShapes`
2149f77909aSYuanqiang Liu     onlyUsedByWithShapes.clear();
2159f77909aSYuanqiang Liu     funcOp.walk([&](Operation *op) {
2169f77909aSYuanqiang Liu       calOnlyUsedByWithShapesRecursively(op, /*prevOutput=*/nullptr);
2179f77909aSYuanqiang Liu     });
2189f77909aSYuanqiang Liu     LLVM_DEBUG({
2199f77909aSYuanqiang Liu       llvm::dbgs() << "onlyUsedByWithShapes table: \n";
2209f77909aSYuanqiang Liu       for (auto it : onlyUsedByWithShapes)
2219f77909aSYuanqiang Liu         llvm::dbgs() << *it << "\n";
2229f77909aSYuanqiang Liu     });
2239f77909aSYuanqiang Liu 
2249f77909aSYuanqiang Liu     // collect all the shape.with_shape ops.
2259f77909aSYuanqiang Liu     std::vector<shape::WithOp> allWithOps;
2269f77909aSYuanqiang Liu     funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });
2279f77909aSYuanqiang Liu 
2289f77909aSYuanqiang Liu     DenseMap<Value, SmallVector<Operation *, 8>> clusters =
2299f77909aSYuanqiang Liu         constructClustersForEachShape(allWithOps, funcOp);
2309f77909aSYuanqiang Liu     constructShapeFunc(allWithOps, context, clusters, symbolTable,
2319f77909aSYuanqiang Liu                        dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);
2329f77909aSYuanqiang Liu 
2339f77909aSYuanqiang Liu     for (shape::WithOp withOp : allWithOps) {
2349f77909aSYuanqiang Liu       Value value = withOp.getOperand();
235c3728d28SXiang Li       for (Operation *user :
236c3728d28SXiang Li            llvm::make_early_inc_range(withOp.getResult().getUsers())) {
237c3728d28SXiang Li         if (auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user)) {
238c3728d28SXiang Li           // For pattern like
239c3728d28SXiang Li           //   %1 = shape.with_shape %arg1, %0
240c3728d28SXiang Li           //   %2 = shape.value_of %1
241c3728d28SXiang Li           // because shape.value doesn't care the shape, the shape.with_shape is
242c3728d28SXiang Li           // redundant.
243c3728d28SXiang Li           // If type of %arg1 and %2 has same type, just
244c3728d28SXiang Li           //   replaced %2 with %arg1.
245c3728d28SXiang Li           // If type of %arg1 has different type like !shape.value_shape,
246c3728d28SXiang Li           // transform into
247c3728d28SXiang Li           //   %2 = shape.value_of %arg1
248c3728d28SXiang Li           if (valueOf.getType() == value.getType())
249c3728d28SXiang Li             valueOf.replaceAllUsesWith(value);
250c3728d28SXiang Li           else
251c3728d28SXiang Li             valueOf.setOperand(value);
252c3728d28SXiang Li         }
2539f77909aSYuanqiang Liu       }
2549f77909aSYuanqiang Liu     }
2559f77909aSYuanqiang Liu 
2569f77909aSYuanqiang Liu     // Apply patterns, note this also performs DCE.
257*09dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(funcOp, {})))
2589f77909aSYuanqiang Liu       return signalPassFailure();
2599f77909aSYuanqiang Liu   });
2609f77909aSYuanqiang Liu }
2619f77909aSYuanqiang Liu 
2629f77909aSYuanqiang Liu DenseMap<Value, SmallVector<Operation *, 8>>
2639f77909aSYuanqiang Liu OutlineShapeComputationPass::constructClustersForEachShape(
2649f77909aSYuanqiang Liu     const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
2659f77909aSYuanqiang Liu   DenseMap<Value, DenseSet<Operation *>> clusters;
2669f77909aSYuanqiang Liu   for (shape::WithOp withOp : allWithOps) {
2679f77909aSYuanqiang Liu     Value shape = withOp.getShape();
2689f77909aSYuanqiang Liu     if (clusters.count(shape) == 0)
2699f77909aSYuanqiang Liu       getClusterFromValue(shape, clusters);
2709f77909aSYuanqiang Liu   }
2719f77909aSYuanqiang Liu   return getOrderedClusters(clusters, funcOp);
2729f77909aSYuanqiang Liu }
2739f77909aSYuanqiang Liu 
2749f77909aSYuanqiang Liu // The output of a cluster is the `shape`, and the inputs are the outputs of
2759f77909aSYuanqiang Liu // operations who are not in `onlyUsedByWithShapes`
2769f77909aSYuanqiang Liu void OutlineShapeComputationPass::getClusterFromValue(
2779f77909aSYuanqiang Liu     Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) {
2789f77909aSYuanqiang Liu   DenseSet<Operation *> cluster;
2799f77909aSYuanqiang Liu 
2809f77909aSYuanqiang Liu   DenseSet<Operation *> visited;
2819f77909aSYuanqiang Liu   std::queue<Operation *> queue;
2829f77909aSYuanqiang Liu 
2839f77909aSYuanqiang Liu   // defOp == nullptr means shape is the argument of the func op
2849f77909aSYuanqiang Liu   if (Operation *defOp = shape.getDefiningOp()) {
2859f77909aSYuanqiang Liu     visited.insert(defOp);
2869f77909aSYuanqiang Liu     queue.push(defOp);
2879f77909aSYuanqiang Liu   }
2889f77909aSYuanqiang Liu   while (!queue.empty()) {
2899f77909aSYuanqiang Liu     Operation *op = queue.front();
2909f77909aSYuanqiang Liu     queue.pop();
2919f77909aSYuanqiang Liu     if (onlyUsedByWithShapes.contains(op)) {
2929f77909aSYuanqiang Liu       cluster.insert(op);
2939f77909aSYuanqiang Liu       for (Value inp : op->getOperands()) {
2949f77909aSYuanqiang Liu         Operation *inpDefOp = inp.getDefiningOp();
29565a5b18aSKazu Hirata         if (nullptr != inpDefOp && visited.insert(inpDefOp).second)
2969f77909aSYuanqiang Liu           queue.push(inpDefOp);
2979f77909aSYuanqiang Liu       }
2989f77909aSYuanqiang Liu     }
2999f77909aSYuanqiang Liu   }
3009f77909aSYuanqiang Liu 
3019f77909aSYuanqiang Liu   clusters[shape] = std::move(cluster);
3029f77909aSYuanqiang Liu }
3039f77909aSYuanqiang Liu 
3049f77909aSYuanqiang Liu // Returns whether `op` is a shape.with_shape, or all the users' of `op`
3059f77909aSYuanqiang Liu // eventually point to the shape operand of shape.with_shape ops
3069f77909aSYuanqiang Liu bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
3079f77909aSYuanqiang Liu     Operation *op, Value prevOutput) {
3089f77909aSYuanqiang Liu   if (onlyUsedByWithShapes.contains(op))
3099f77909aSYuanqiang Liu     return true;
3109f77909aSYuanqiang Liu 
3119f77909aSYuanqiang Liu   if (auto withOp = llvm::dyn_cast<shape::WithOp>(op))
3129f77909aSYuanqiang Liu     return withOp.getShape() == prevOutput;
3139f77909aSYuanqiang Liu 
3149f77909aSYuanqiang Liu   if (op->use_empty())
3159f77909aSYuanqiang Liu     return false;
3169f77909aSYuanqiang Liu 
3179f77909aSYuanqiang Liu   for (Value oup : op->getResults())
3189f77909aSYuanqiang Liu     for (Operation *user : oup.getUsers())
3199f77909aSYuanqiang Liu       if (!calOnlyUsedByWithShapesRecursively(user, oup))
3209f77909aSYuanqiang Liu         return false;
3219f77909aSYuanqiang Liu 
3229f77909aSYuanqiang Liu   onlyUsedByWithShapes.insert(op);
3239f77909aSYuanqiang Liu   return true;
3249f77909aSYuanqiang Liu }
3259f77909aSYuanqiang Liu 
3269f77909aSYuanqiang Liu } // namespace
3279f77909aSYuanqiang Liu 
3289f77909aSYuanqiang Liu std::unique_ptr<OperationPass<ModuleOp>>
3299f77909aSYuanqiang Liu mlir::createOutlineShapeComputationPass() {
3309f77909aSYuanqiang Liu   return std::make_unique<OutlineShapeComputationPass>();
3319f77909aSYuanqiang Liu }
332