xref: /llvm-project/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //====----- OutlineShapeComputation.cpp -----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Func/IR/FuncOps.h"
10 #include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
11 #include "mlir/Dialect/Shape/IR/Shape.h"
12 #include "mlir/Dialect/Shape/Transforms/Passes.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 #include "llvm/ADT/DenseSet.h"
20 #include "llvm/Support/Debug.h"
21 #include <queue>
22 #include <unordered_set>
23 #include <vector>
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_OUTLINESHAPECOMPUTATION
27 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
28 } // namespace mlir
29 
30 #define DEBUG_TYPE "outline-shape-computation"
31 
32 using namespace mlir;
33 
34 namespace {
35 
36 // A Value is an input of the cluster if it is an operand of an operation in the
37 // cluster and its defining operation is not in the cluster.
38 SmallVector<Value, 4>
39 getInputsOfCluster(const llvm::SmallVector<Operation *, 8> &cluster) {
40   SmallVector<Value, 4> inputs;
41   llvm::SmallDenseSet<Value> inputSet;
42   llvm::SmallDenseSet<Operation *> opSet;
43   for (Operation *op : cluster) {
44     bool inserted = opSet.insert(op).second;
45     (void)inserted;
46     assert(inserted && "cluster contains duplicate operations");
47   }
48 
49   for (Operation *op : cluster) {
50     for (Value operand : op->getOperands()) {
51       Operation *operandOp = operand.getDefiningOp();
52       if (opSet.contains(operandOp)) {
53         // Skip if defining op is in the cluster.
54         continue;
55       }
56       if (inputSet.insert(operand).second)
57         inputs.push_back(operand);
58     }
59   }
60   return inputs;
61 }
62 
63 // Create a shape.func representing the shape computation for `shape`.
64 std::pair<shape::FuncOp, SmallVector<Value>>
65 createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
66                       Value shape, StringRef fnName, Location loc) {
67   SmallVector<Value, 4> inputs = getInputsOfCluster(cluster);
68   auto fnType =
69       cluster.empty()
70           ? b.getFunctionType(shape.getType(), shape.getType())
71           : b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType());
72   shape::FuncOp fnOp = b.create<shape::FuncOp>(loc, fnName, fnType);
73   Block *block = fnOp.addEntryBlock();
74   b.setInsertionPointToEnd(block);
75   IRMapping bvm;
76   if (cluster.empty()) {
77     bvm.map(shape, fnOp.getArgument(0));
78   } else {
79     for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
80       bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
81   }
82 
83   for (Operation *op : cluster)
84     b.clone(*op, bvm);
85   llvm::SmallVector<Value, 4> fnReturns;
86   fnReturns.push_back(bvm.lookupOrDefault(shape));
87 
88   b.create<shape::ReturnOp>(loc, fnReturns);
89   fnOp.setPrivate();
90   return std::make_pair(fnOp, inputs);
91 }
92 
93 // The operations in the cluster might be unsorted, which could be inconvenient
94 // when creating shape.func op.
95 DenseMap<Value, SmallVector<Operation *, 8>>
96 getOrderedClusters(const DenseMap<Value, DenseSet<Operation *>> &clusters,
97                    func::FuncOp funcOp) {
98   // Compute all clusters that each operation is in
99   DenseMap<Operation *, SmallVector<Value>> op2Shapes;
100   for (const auto &it : clusters) {
101     Value shape = it.first;
102     const DenseSet<Operation *> &cluster = it.second;
103     for (Operation *cOp : cluster)
104       op2Shapes[cOp].push_back(shape);
105   }
106 
107   // Iterate through all operations in order. Get all the clusters `cOp` belongs
108   // to and construct the new ordered cluster as it traverses.
109   DenseMap<Value, SmallVector<Operation *, 8>> orderedClusters;
110   funcOp.walk([&](Operation *op) {
111     auto it = op2Shapes.find(op);
112     if (it != op2Shapes.end()) {
113       Operation *cOp = it->first;
114       for (Value shape : it->second)
115         orderedClusters[shape].push_back(cOp);
116     }
117   });
118 
119   return orderedClusters;
120 }
121 
122 void constructShapeFunc(
123     const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
124     DenseMap<Value, SmallVector<Operation *, 8>> &clusters,
125     SymbolTable &symbolTable,
126     DenseMap<Value, shape::ShapeMappingValue> &dynShape2ShapeFunc,
127     func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) {
128   std::string shapeCalculationNamePrefix = "shape_cal_";
129   int shapeCalculationNameIdx = 0;
130   OpBuilder builder(context);
131 
132   // Construct a shape function
133   for (shape::WithOp withOp : allWithOps) {
134     Value value = withOp.getOperand();
135     Value shape = withOp.getShape();
136     RankedTensorType rankedType = dyn_cast<RankedTensorType>(value.getType());
137     if (rankedType == nullptr)
138       continue;
139 
140     const SmallVector<Operation *, 8> &cluster = clusters[shape];
141     shape::ShapeMappingValue shapeMappingValue;
142     auto it = dynShape2ShapeFunc.find(shape);
143     if (it == dynShape2ShapeFunc.end()) {
144       std::string name = shapeCalculationNamePrefix +
145                          std::to_string(shapeCalculationNameIdx++);
146       Location loc = value.getLoc();
147       builder.setInsertionPointAfter(funcOp);
148       auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
149       const SmallVector<Value> &inputs = pair.second;
150       shape::FuncOp shapeFuncOp = pair.first;
151       StringAttr insertedName = symbolTable.insert(shapeFuncOp);
152       auto symbol = FlatSymbolRefAttr::get(context, insertedName);
153 
154       shapeMappingValue.funcSymbol = symbol;
155       shapeMappingValue.inputs = inputs;
156     } else {
157       shapeMappingValue = it->second;
158     }
159     dynShape2ShapeFunc[shape] = shapeMappingValue;
160     shapeMappingAnalysis.shapeMapping.insert(
161         std::make_pair(value, shapeMappingValue));
162   }
163 }
164 
// Pass that outlines the shape computation feeding each shape.with_shape op
// into a standalone shape.func, and records the mapping from dynamically
// shaped values to their shape functions in the ShapeMappingAnalysis.
struct OutlineShapeComputationPass
    : public impl::OutlineShapeComputationBase<OutlineShapeComputationPass> {

  void runOnOperation() override;

private:
  // Returns whether `op` is a shape.with_shape reached via its shape operand,
  // or all of `op`'s results (transitively) flow only into such shape
  // operands. Positive answers are cached in `onlyUsedByWithShapes`.
  bool calOnlyUsedByWithShapesRecursively(Operation *op, Value prevOutput);

  // Computes the cluster of operations that produce `shape` (walking
  // backwards through `onlyUsedByWithShapes` members) and stores it in
  // `clusters[shape]`.
  void getClusterFromValue(Value shape,
                           DenseMap<Value, DenseSet<Operation *>> &clusters);

  // Builds one program-ordered operation cluster per distinct shape used by
  // the shape.with_shape ops of `funcOp`.
  DenseMap<Value, SmallVector<Operation *, 8>>
  constructClustersForEachShape(const std::vector<shape::WithOp> &allWithOps,
                                func::FuncOp funcOp);

  // Operations whose results are (transitively) used only as the shape
  // operand of shape.with_shape ops; cleared and recomputed per function in
  // runOnOperation.
  DenseSet<Operation *> onlyUsedByWithShapes;
};
182 
183 class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
184   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
185 
186   LogicalResult matchAndRewrite(tensor::DimOp op,
187                                 PatternRewriter &rewriter) const override {
188     auto shapeOf =
189         rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getSource());
190     rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
191                                                     op.getIndex());
192     return success();
193   }
194 };
195 
void OutlineShapeComputationPass::runOnOperation() {
  ModuleOp moduleOp = getOperation();
  SymbolTable symbolTable(moduleOp);
  // Shared across all functions in the module so that shape values reused by
  // several with_shape ops map to a single outlined shape.func.
  DenseMap<Value, shape::ShapeMappingValue> dynShape2ShapeFunc;
  auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
  // TODO: This is as we populate this analysis during a pass that mutates. This
  // pass currently requires 1 single module being compiled.
  shapeMappingAnalysis.shapeMapping.clear();
  markAnalysesPreserved<shape::ShapeMappingAnalysis>();

  moduleOp.walk([&](func::FuncOp funcOp) {
    MLIRContext *context = funcOp.getContext();
    // First normalize tensor.dim into shape.shape_of + shape.get_extent so
    // the shape computation is expressed in the shape dialect.
    RewritePatternSet prevPatterns(context);
    prevPatterns.insert<TensorDimOpRewriter>(context);
    // NOTE: this `return` only exits the walk callback; the walk continues to
    // the remaining functions with the pass already marked as failed.
    if (failed(applyPatternsGreedily(funcOp, std::move(prevPatterns))))
      return signalPassFailure();

    // initialize class member `onlyUsedByWithShapes`
    onlyUsedByWithShapes.clear();
    funcOp.walk([&](Operation *op) {
      calOnlyUsedByWithShapesRecursively(op, /*prevOutput=*/nullptr);
    });
    LLVM_DEBUG({
      llvm::dbgs() << "onlyUsedByWithShapes table: \n";
      for (auto it : onlyUsedByWithShapes)
        llvm::dbgs() << *it << "\n";
    });

    // collect all the shape.with_shape ops.
    std::vector<shape::WithOp> allWithOps;
    funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });

    // Group the shape-computing ops per shape value, then outline each group
    // into a shape.func and record the value-to-function mapping.
    DenseMap<Value, SmallVector<Operation *, 8>> clusters =
        constructClustersForEachShape(allWithOps, funcOp);
    constructShapeFunc(allWithOps, context, clusters, symbolTable,
                       dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);

    // With the shape functions outlined, the with_shape wrappers are no
    // longer needed by their value_of consumers; bypass them.
    for (shape::WithOp withOp : allWithOps) {
      Value value = withOp.getOperand();
      // early_inc_range: users are erased/retargeted while iterating.
      for (Operation *user :
           llvm::make_early_inc_range(withOp.getResult().getUsers())) {
        if (auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user)) {
          // For pattern like
          //   %1 = shape.with_shape %arg1, %0
          //   %2 = shape.value_of %1
          // because shape.value doesn't care the shape, the shape.with_shape is
          // redundant.
          // If type of %arg1 and %2 has same type, just
          //   replaced %2 with %arg1.
          // If type of %arg1 has different type like !shape.value_shape,
          // transform into
          //   %2 = shape.value_of %arg1
          if (valueOf.getType() == value.getType())
            valueOf.replaceAllUsesWith(value);
          else
            valueOf.setOperand(value);
        }
      }
    }

    // Apply patterns, note this also performs DCE.
    if (failed(applyPatternsGreedily(funcOp, {})))
      return signalPassFailure();
  });
}
261 
262 DenseMap<Value, SmallVector<Operation *, 8>>
263 OutlineShapeComputationPass::constructClustersForEachShape(
264     const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
265   DenseMap<Value, DenseSet<Operation *>> clusters;
266   for (shape::WithOp withOp : allWithOps) {
267     Value shape = withOp.getShape();
268     if (clusters.count(shape) == 0)
269       getClusterFromValue(shape, clusters);
270   }
271   return getOrderedClusters(clusters, funcOp);
272 }
273 
274 // The output of a cluster is the `shape`, and the inputs are the outputs of
275 // operations who are not in `onlyUsedByWithShapes`
276 void OutlineShapeComputationPass::getClusterFromValue(
277     Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) {
278   DenseSet<Operation *> cluster;
279 
280   DenseSet<Operation *> visited;
281   std::queue<Operation *> queue;
282 
283   // defOp == nullptr means shape is the argument of the func op
284   if (Operation *defOp = shape.getDefiningOp()) {
285     visited.insert(defOp);
286     queue.push(defOp);
287   }
288   while (!queue.empty()) {
289     Operation *op = queue.front();
290     queue.pop();
291     if (onlyUsedByWithShapes.contains(op)) {
292       cluster.insert(op);
293       for (Value inp : op->getOperands()) {
294         Operation *inpDefOp = inp.getDefiningOp();
295         if (nullptr != inpDefOp && visited.insert(inpDefOp).second)
296           queue.push(inpDefOp);
297       }
298     }
299   }
300 
301   clusters[shape] = std::move(cluster);
302 }
303 
304 // Returns whether `op` is a shape.with_shape, or all the users' of `op`
305 // eventually point to the shape operand of shape.with_shape ops
306 bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
307     Operation *op, Value prevOutput) {
308   if (onlyUsedByWithShapes.contains(op))
309     return true;
310 
311   if (auto withOp = llvm::dyn_cast<shape::WithOp>(op))
312     return withOp.getShape() == prevOutput;
313 
314   if (op->use_empty())
315     return false;
316 
317   for (Value oup : op->getResults())
318     for (Operation *user : oup.getUsers())
319       if (!calOnlyUsedByWithShapesRecursively(user, oup))
320         return false;
321 
322   onlyUsedByWithShapes.insert(op);
323   return true;
324 }
325 
326 } // namespace
327 
/// Creates an instance of the OutlineShapeComputation pass, which runs on a
/// ModuleOp.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createOutlineShapeComputationPass() {
  return std::make_unique<OutlineShapeComputationPass>();
}
332