xref: /llvm-project/mlir/lib/Transforms/ViewOpGraph.cpp (revision 0c7890c844fdc7adb6d0cf58403e3fdd7407915d)
1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/ViewOpGraph.h"
10 #include "PassDetail.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/Support/IndentedOstream.h"
14 #include "llvm/Support/Format.h"
15 #include "llvm/Support/GraphWriter.h"
16 
17 using namespace mlir;
18 
19 static const StringRef kLineStyleControlFlow = "dashed";
20 static const StringRef kLineStyleDataFlow = "solid";
21 static const StringRef kShapeNode = "ellipse";
22 static const StringRef kShapeNone = "plain";
23 
24 /// Return the size limits for eliding large attributes.
25 static int64_t getLargeAttributeSizeLimit() {
26   // Use the default from the printer flags if possible.
27   if (Optional<int64_t> limit = OpPrintingFlags().getLargeElementsAttrLimit())
28     return *limit;
29   return 16;
30 }
31 
32 /// Return all values printed onto a stream as a string.
33 static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
34   std::string buf;
35   llvm::raw_string_ostream os(buf);
36   func(os);
37   return os.str();
38 }
39 
40 /// Escape special characters such as '\n' and quotation marks.
41 static std::string escapeString(std::string str) {
42   return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
43 }
44 
/// Wrap `str` in double quotation marks. The contents are not escaped here;
/// callers that need escaping must apply it beforehand.
static std::string quoteString(std::string str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted += '"';
  quoted += str;
  quoted += '"';
  return quoted;
}
47 
/// Attributes of a DOT node or edge, keyed by attribute name. Values are
/// stored in the exact textual form that is printed after `key = `, i.e.
/// callers quote/escape them before insertion where needed.
48 using AttributeMap = llvm::StringMap<std::string>;
49 
50 namespace {
51 
52 /// This struct represents a node in the DOT language. Each node has an
53 /// identifier and an optional identifier for the cluster (subgraph) that
54 /// contains the node.
55 /// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
56 /// not between clusters. However, edges can be clipped to the boundary of a
57 /// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
58 /// cluster, an invisible "anchor" node is created.
59 struct Node {
60 public:
61   Node(int id = 0, Optional<int> clusterId = llvm::None)
62       : id(id), clusterId(clusterId) {}
63 
64   int id;
65   Optional<int> clusterId;
66 };
67 
68 /// This pass generates a Graphviz dataflow visualization of an MLIR operation.
69 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information
70 /// about the Graphviz DOT language.
71 class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
72 public:
73   PrintOpPass(raw_ostream &os) : os(os) {}
74   PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
75 
76   void runOnOperation() override {
77     emitGraph([&]() {
78       processOperation(getOperation());
79       emitAllEdgeStmts();
80     });
81   }
82 
83   /// Create a CFG graph for a region. Used in `Region::viewGraph`.
84   void emitRegionCFG(Region &region) {
85     printControlFlowEdges = true;
86     printDataFlowEdges = false;
87     emitGraph([&]() { processRegion(region); });
88   }
89 
90 private:
91   /// Emit all edges. This function should be called after all nodes have been
92   /// emitted.
93   void emitAllEdgeStmts() {
94     for (const std::string &edge : edges)
95       os << edge << ";\n";
96     edges.clear();
97   }
98 
99   /// Emit a cluster (subgraph). The specified builder generates the body of the
100   /// cluster. Return the anchor node of the cluster.
101   Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
102     int clusterId = ++counter;
103     os << "subgraph cluster_" << clusterId << " {\n";
104     os.indent();
105     // Emit invisible anchor node from/to which arrows can be drawn.
106     Node anchorNode = emitNodeStmt(" ", kShapeNone);
107     os << attrStmt("label", quoteString(escapeString(label))) << ";\n";
108     builder();
109     os.unindent();
110     os << "}\n";
111     return Node(anchorNode.id, clusterId);
112   }
113 
114   /// Generate an attribute statement.
115   std::string attrStmt(const Twine &key, const Twine &value) {
116     return (key + " = " + value).str();
117   }
118 
119   /// Emit an attribute list.
120   void emitAttrList(raw_ostream &os, const AttributeMap &map) {
121     os << "[";
122     interleaveComma(map, os, [&](const auto &it) {
123       os << this->attrStmt(it.getKey(), it.getValue());
124     });
125     os << "]";
126   }
127 
128   // Print an MLIR attribute to `os`. Large attributes are truncated.
129   void emitMlirAttr(raw_ostream &os, Attribute attr) {
130     // A value used to elide large container attribute.
131     int64_t largeAttrLimit = getLargeAttributeSizeLimit();
132 
133     // Always emit splat attributes.
134     if (attr.isa<SplatElementsAttr>()) {
135       attr.print(os);
136       return;
137     }
138 
139     // Elide "big" elements attributes.
140     auto elements = attr.dyn_cast<ElementsAttr>();
141     if (elements && elements.getNumElements() > largeAttrLimit) {
142       os << std::string(elements.getType().getRank(), '[') << "..."
143          << std::string(elements.getType().getRank(), ']') << " : "
144          << elements.getType();
145       return;
146     }
147 
148     auto array = attr.dyn_cast<ArrayAttr>();
149     if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
150       os << "[...]";
151       return;
152     }
153 
154     // Print all other attributes.
155     std::string buf;
156     llvm::raw_string_ostream ss(buf);
157     attr.print(ss);
158     os << truncateString(ss.str());
159   }
160 
161   /// Append an edge to the list of edges.
162   /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
163   void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) {
164     AttributeMap attrs;
165     attrs["style"] = style.str();
166     // Do not label edges that start/end at a cluster boundary. Such edges are
167     // clipped at the boundary, but labels are not. This can lead to labels
168     // floating around without any edge next to them.
169     if (!n1.clusterId && !n2.clusterId)
170       attrs["label"] = quoteString(escapeString(label));
171     // Use `ltail` and `lhead` to draw edges between clusters.
172     if (n1.clusterId)
173       attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
174     if (n2.clusterId)
175       attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
176 
177     edges.push_back(strFromOs([&](raw_ostream &os) {
178       os << llvm::format("v%i -> v%i ", n1.id, n2.id);
179       emitAttrList(os, attrs);
180     }));
181   }
182 
183   /// Emit a graph. The specified builder generates the body of the graph.
184   void emitGraph(function_ref<void()> builder) {
185     os << "digraph G {\n";
186     os.indent();
187     // Edges between clusters are allowed only in compound mode.
188     os << attrStmt("compound", "true") << ";\n";
189     builder();
190     os.unindent();
191     os << "}\n";
192   }
193 
194   /// Emit a node statement.
195   Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) {
196     int nodeId = ++counter;
197     AttributeMap attrs;
198     attrs["label"] = quoteString(escapeString(label));
199     attrs["shape"] = shape.str();
200     os << llvm::format("v%i ", nodeId);
201     emitAttrList(os, attrs);
202     os << ";\n";
203     return Node(nodeId);
204   }
205 
206   /// Generate a label for an operation.
207   std::string getLabel(Operation *op) {
208     return strFromOs([&](raw_ostream &os) {
209       // Print operation name and type.
210       os << op->getName();
211       if (printResultTypes) {
212         os << " : (";
213         std::string buf;
214         llvm::raw_string_ostream ss(buf);
215         interleaveComma(op->getResultTypes(), ss);
216         os << truncateString(ss.str()) << ")";
217         os << ")";
218       }
219 
220       // Print attributes.
221       if (printAttrs) {
222         os << "\n";
223         for (const NamedAttribute &attr : op->getAttrs()) {
224           os << '\n' << attr.getName().getValue() << ": ";
225           emitMlirAttr(os, attr.getValue());
226         }
227       }
228     });
229   }
230 
231   /// Generate a label for a block argument.
232   std::string getLabel(BlockArgument arg) {
233     return "arg" + std::to_string(arg.getArgNumber());
234   }
235 
236   /// Process a block. Emit a cluster and one node per block argument and
237   /// operation inside the cluster.
238   void processBlock(Block &block) {
239     emitClusterStmt([&]() {
240       for (BlockArgument &blockArg : block.getArguments())
241         valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
242 
243       // Emit a node for each operation.
244       Optional<Node> prevNode;
245       for (Operation &op : block) {
246         Node nextNode = processOperation(&op);
247         if (printControlFlowEdges && prevNode)
248           emitEdgeStmt(*prevNode, nextNode, /*label=*/"",
249                        kLineStyleControlFlow);
250         prevNode = nextNode;
251       }
252     });
253   }
254 
255   /// Process an operation. If the operation has regions, emit a cluster.
256   /// Otherwise, emit a node.
257   Node processOperation(Operation *op) {
258     Node node;
259     if (op->getNumRegions() > 0) {
260       // Emit cluster for op with regions.
261       node = emitClusterStmt(
262           [&]() {
263             for (Region &region : op->getRegions())
264               processRegion(region);
265           },
266           getLabel(op));
267     } else {
268       node = emitNodeStmt(getLabel(op));
269     }
270 
271     // Insert data flow edges originating from each operand.
272     if (printDataFlowEdges) {
273       unsigned numOperands = op->getNumOperands();
274       for (unsigned i = 0; i < numOperands; i++)
275         emitEdgeStmt(valueToNode[op->getOperand(i)], node,
276                      /*label=*/numOperands == 1 ? "" : std::to_string(i),
277                      kLineStyleDataFlow);
278     }
279 
280     for (Value result : op->getResults())
281       valueToNode[result] = node;
282 
283     return node;
284   }
285 
286   /// Process a region.
287   void processRegion(Region &region) {
288     for (Block &block : region.getBlocks())
289       processBlock(block);
290   }
291 
292   /// Truncate long strings.
293   std::string truncateString(std::string str) {
294     if (str.length() <= maxLabelLen)
295       return str;
296     return str.substr(0, maxLabelLen) + "...";
297   }
298 
299   /// Output stream to write DOT file to.
300   raw_indented_ostream os;
301   /// A list of edges. For simplicity, should be emitted after all nodes were
302   /// emitted.
303   std::vector<std::string> edges;
304   /// Mapping of SSA values to Graphviz nodes/clusters.
305   DenseMap<Value, Node> valueToNode;
306   /// Counter for generating unique node/subgraph identifiers.
307   int counter = 0;
308 };
309 
310 } // namespace
311 
312 std::unique_ptr<Pass>
313 mlir::createPrintOpGraphPass(raw_ostream &os) {
314   return std::make_unique<PrintOpPass>(os);
315 }
316 
317 /// Generate a CFG for a region and show it in a window.
318 static void llvmViewGraph(Region &region, const Twine &name) {
319   int fd;
320   std::string filename = llvm::createGraphFilename(name.str(), fd);
321   {
322     llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
323     if (fd == -1) {
324       llvm::errs() << "error opening file '" << filename << "' for writing\n";
325       return;
326     }
327     PrintOpPass pass(os);
328     pass.emitRegionCFG(region);
329   }
330   llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT);
331 }
332 
333 void mlir::Region::viewGraph(const Twine &regionName) {
334   llvmViewGraph(*this, regionName);
335 }
336 
337 void mlir::Region::viewGraph() { viewGraph("region"); }
338