Lines Matching full:shape

10 #include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
11 #include "mlir/Dialect/Shape/IR/Shape.h"
12 #include "mlir/Dialect/Shape/Transforms/Passes.h"
27 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
30 #define DEBUG_TYPE "outline-shape-computation"
63 // Create a shape.func representing the shape computation for `shape`.
64 std::pair<shape::FuncOp, SmallVector<Value>>
66 Value shape, StringRef fnName, Location loc) {
70 ? b.getFunctionType(shape.getType(), shape.getType())
71 : b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType());
72 shape::FuncOp fnOp = b.create<shape::FuncOp>(loc, fnName, fnType);
77 bvm.map(shape, fnOp.getArgument(0));
86 fnReturns.push_back(bvm.lookupOrDefault(shape));
88 b.create<shape::ReturnOp>(loc, fnReturns);
94 // when creating shape.func op.
101 Value shape = it.first;
104 op2Shapes[cOp].push_back(shape);
114 for (Value shape : it->second)
115 orderedClusters[shape].push_back(cOp);
123 const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
126 DenseMap<Value, shape::ShapeMappingValue> &dynShape2ShapeFunc,
127 func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) {
132 // Construct a shape function
133 for (shape::WithOp withOp : allWithOps) {
135 Value shape = withOp.getShape();
140 const SmallVector<Operation *, 8> &cluster = clusters[shape];
141 shape::ShapeMappingValue shapeMappingValue;
142 auto it = dynShape2ShapeFunc.find(shape);
148 auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
150 shape::FuncOp shapeFuncOp = pair.first;
159 dynShape2ShapeFunc[shape] = shapeMappingValue;
173 void getClusterFromValue(Value shape,
177 constructClustersForEachShape(const std::vector<shape::WithOp> &allWithOps,
189 rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getSource());
190 rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
199 DenseMap<Value, shape::ShapeMappingValue> dynShape2ShapeFunc;
200 auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
204 markAnalysesPreserved<shape::ShapeMappingAnalysis>();
224 // collect all the shape.with_shape ops.
225 std::vector<shape::WithOp> allWithOps;
226 funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });
233 for (shape::WithOp withOp : allWithOps) {
237 if (auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user)) {
239 // %1 = shape.with_shape %arg1, %0
240 // %2 = shape.value_of %1
241 // because shape.value doesn't care the shape, the shape.with_shape is
245 // If type of %arg1 has different type like !shape.value_shape,
247 // %2 = shape.value_of %arg1
264 const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
266 for (shape::WithOp withOp : allWithOps) {
267 Value shape = withOp.getShape();
268 if (clusters.count(shape) == 0)
269 getClusterFromValue(shape, clusters);
274 // The output of a cluster is the `shape`, and the inputs are the outputs of
277 Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) {
283 // defOp == nullptr means shape is the argument of the func op
284 if (Operation *defOp = shape.getDefiningOp()) {
301 clusters[shape] = std::move(cluster);
304 // Returns whether `op` is a shape.with_shape, or all the users' of `op`
305 // eventually point to the shape operand of shape.with_shape ops
311 if (auto withOp = llvm::dyn_cast<shape::WithOp>(op))