xref: /llvm-project/mlir/lib/Transforms/ViewOpGraph.cpp (revision 039b969b32b64b64123dce30dd28ec4e343d893f)
1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/ViewOpGraph.h"
10 #include "PassDetail.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/Support/IndentedOstream.h"
15 #include "llvm/Support/Format.h"
16 #include "llvm/Support/GraphWriter.h"
17 #include <utility>
18 
19 using namespace mlir;
20 
21 static const StringRef kLineStyleControlFlow = "dashed";
22 static const StringRef kLineStyleDataFlow = "solid";
23 static const StringRef kShapeNode = "ellipse";
24 static const StringRef kShapeNone = "plain";
25 
26 /// Return the size limits for eliding large attributes.
27 static int64_t getLargeAttributeSizeLimit() {
28   // Use the default from the printer flags if possible.
29   if (Optional<int64_t> limit = OpPrintingFlags().getLargeElementsAttrLimit())
30     return *limit;
31   return 16;
32 }
33 
34 /// Return all values printed onto a stream as a string.
35 static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
36   std::string buf;
37   llvm::raw_string_ostream os(buf);
38   func(os);
39   return os.str();
40 }
41 
42 /// Escape special characters such as '\n' and quotation marks.
43 static std::string escapeString(std::string str) {
44   return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
45 }
46 
/// Wrap `str` in double quotation marks. The content is copied verbatim; no
/// escaping is performed here.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted.push_back('"');
  quoted.append(str);
  quoted.push_back('"');
  return quoted;
}
51 
/// Maps a DOT attribute key (e.g. "label", "shape", "style") to its value as
/// it should appear in the output, i.e. already quoted/escaped where needed.
using AttributeMap = llvm::StringMap<std::string>;
53 
54 namespace {
55 
56 /// This struct represents a node in the DOT language. Each node has an
57 /// identifier and an optional identifier for the cluster (subgraph) that
58 /// contains the node.
59 /// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
60 /// not between clusters. However, edges can be clipped to the boundary of a
61 /// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
62 /// cluster, an invisible "anchor" node is created.
63 struct Node {
64 public:
65   Node(int id = 0, Optional<int> clusterId = llvm::None)
66       : id(id), clusterId(clusterId) {}
67 
68   int id;
69   Optional<int> clusterId;
70 };
71 
72 /// This pass generates a Graphviz dataflow visualization of an MLIR operation.
73 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information
74 /// about the Graphviz DOT language.
75 class PrintOpPass : public ViewOpGraphBase<PrintOpPass> {
76 public:
77   PrintOpPass(raw_ostream &os) : os(os) {}
78   PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
79 
80   void runOnOperation() override {
81     emitGraph([&]() {
82       processOperation(getOperation());
83       emitAllEdgeStmts();
84     });
85   }
86 
87   /// Create a CFG graph for a region. Used in `Region::viewGraph`.
88   void emitRegionCFG(Region &region) {
89     printControlFlowEdges = true;
90     printDataFlowEdges = false;
91     emitGraph([&]() { processRegion(region); });
92   }
93 
94 private:
95   /// Emit all edges. This function should be called after all nodes have been
96   /// emitted.
97   void emitAllEdgeStmts() {
98     for (const std::string &edge : edges)
99       os << edge << ";\n";
100     edges.clear();
101   }
102 
103   /// Emit a cluster (subgraph). The specified builder generates the body of the
104   /// cluster. Return the anchor node of the cluster.
105   Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
106     int clusterId = ++counter;
107     os << "subgraph cluster_" << clusterId << " {\n";
108     os.indent();
109     // Emit invisible anchor node from/to which arrows can be drawn.
110     Node anchorNode = emitNodeStmt(" ", kShapeNone);
111     os << attrStmt("label", quoteString(escapeString(std::move(label))))
112        << ";\n";
113     builder();
114     os.unindent();
115     os << "}\n";
116     return Node(anchorNode.id, clusterId);
117   }
118 
119   /// Generate an attribute statement.
120   std::string attrStmt(const Twine &key, const Twine &value) {
121     return (key + " = " + value).str();
122   }
123 
124   /// Emit an attribute list.
125   void emitAttrList(raw_ostream &os, const AttributeMap &map) {
126     os << "[";
127     interleaveComma(map, os, [&](const auto &it) {
128       os << this->attrStmt(it.getKey(), it.getValue());
129     });
130     os << "]";
131   }
132 
133   // Print an MLIR attribute to `os`. Large attributes are truncated.
134   void emitMlirAttr(raw_ostream &os, Attribute attr) {
135     // A value used to elide large container attribute.
136     int64_t largeAttrLimit = getLargeAttributeSizeLimit();
137 
138     // Always emit splat attributes.
139     if (attr.isa<SplatElementsAttr>()) {
140       attr.print(os);
141       return;
142     }
143 
144     // Elide "big" elements attributes.
145     auto elements = attr.dyn_cast<ElementsAttr>();
146     if (elements && elements.getNumElements() > largeAttrLimit) {
147       os << std::string(elements.getType().getRank(), '[') << "..."
148          << std::string(elements.getType().getRank(), ']') << " : "
149          << elements.getType();
150       return;
151     }
152 
153     auto array = attr.dyn_cast<ArrayAttr>();
154     if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
155       os << "[...]";
156       return;
157     }
158 
159     // Print all other attributes.
160     std::string buf;
161     llvm::raw_string_ostream ss(buf);
162     attr.print(ss);
163     os << truncateString(ss.str());
164   }
165 
166   /// Append an edge to the list of edges.
167   /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
168   void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) {
169     AttributeMap attrs;
170     attrs["style"] = style.str();
171     // Do not label edges that start/end at a cluster boundary. Such edges are
172     // clipped at the boundary, but labels are not. This can lead to labels
173     // floating around without any edge next to them.
174     if (!n1.clusterId && !n2.clusterId)
175       attrs["label"] = quoteString(escapeString(std::move(label)));
176     // Use `ltail` and `lhead` to draw edges between clusters.
177     if (n1.clusterId)
178       attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
179     if (n2.clusterId)
180       attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
181 
182     edges.push_back(strFromOs([&](raw_ostream &os) {
183       os << llvm::format("v%i -> v%i ", n1.id, n2.id);
184       emitAttrList(os, attrs);
185     }));
186   }
187 
188   /// Emit a graph. The specified builder generates the body of the graph.
189   void emitGraph(function_ref<void()> builder) {
190     os << "digraph G {\n";
191     os.indent();
192     // Edges between clusters are allowed only in compound mode.
193     os << attrStmt("compound", "true") << ";\n";
194     builder();
195     os.unindent();
196     os << "}\n";
197   }
198 
199   /// Emit a node statement.
200   Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) {
201     int nodeId = ++counter;
202     AttributeMap attrs;
203     attrs["label"] = quoteString(escapeString(std::move(label)));
204     attrs["shape"] = shape.str();
205     os << llvm::format("v%i ", nodeId);
206     emitAttrList(os, attrs);
207     os << ";\n";
208     return Node(nodeId);
209   }
210 
211   /// Generate a label for an operation.
212   std::string getLabel(Operation *op) {
213     return strFromOs([&](raw_ostream &os) {
214       // Print operation name and type.
215       os << op->getName();
216       if (printResultTypes) {
217         os << " : (";
218         std::string buf;
219         llvm::raw_string_ostream ss(buf);
220         interleaveComma(op->getResultTypes(), ss);
221         os << truncateString(ss.str()) << ")";
222         os << ")";
223       }
224 
225       // Print attributes.
226       if (printAttrs) {
227         os << "\n";
228         for (const NamedAttribute &attr : op->getAttrs()) {
229           os << '\n' << attr.getName().getValue() << ": ";
230           emitMlirAttr(os, attr.getValue());
231         }
232       }
233     });
234   }
235 
236   /// Generate a label for a block argument.
237   std::string getLabel(BlockArgument arg) {
238     return "arg" + std::to_string(arg.getArgNumber());
239   }
240 
241   /// Process a block. Emit a cluster and one node per block argument and
242   /// operation inside the cluster.
243   void processBlock(Block &block) {
244     emitClusterStmt([&]() {
245       for (BlockArgument &blockArg : block.getArguments())
246         valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
247 
248       // Emit a node for each operation.
249       Optional<Node> prevNode;
250       for (Operation &op : block) {
251         Node nextNode = processOperation(&op);
252         if (printControlFlowEdges && prevNode)
253           emitEdgeStmt(*prevNode, nextNode, /*label=*/"",
254                        kLineStyleControlFlow);
255         prevNode = nextNode;
256       }
257     });
258   }
259 
260   /// Process an operation. If the operation has regions, emit a cluster.
261   /// Otherwise, emit a node.
262   Node processOperation(Operation *op) {
263     Node node;
264     if (op->getNumRegions() > 0) {
265       // Emit cluster for op with regions.
266       node = emitClusterStmt(
267           [&]() {
268             for (Region &region : op->getRegions())
269               processRegion(region);
270           },
271           getLabel(op));
272     } else {
273       node = emitNodeStmt(getLabel(op));
274     }
275 
276     // Insert data flow edges originating from each operand.
277     if (printDataFlowEdges) {
278       unsigned numOperands = op->getNumOperands();
279       for (unsigned i = 0; i < numOperands; i++)
280         emitEdgeStmt(valueToNode[op->getOperand(i)], node,
281                      /*label=*/numOperands == 1 ? "" : std::to_string(i),
282                      kLineStyleDataFlow);
283     }
284 
285     for (Value result : op->getResults())
286       valueToNode[result] = node;
287 
288     return node;
289   }
290 
291   /// Process a region.
292   void processRegion(Region &region) {
293     for (Block &block : region.getBlocks())
294       processBlock(block);
295   }
296 
297   /// Truncate long strings.
298   std::string truncateString(std::string str) {
299     if (str.length() <= maxLabelLen)
300       return str;
301     return str.substr(0, maxLabelLen) + "...";
302   }
303 
304   /// Output stream to write DOT file to.
305   raw_indented_ostream os;
306   /// A list of edges. For simplicity, should be emitted after all nodes were
307   /// emitted.
308   std::vector<std::string> edges;
309   /// Mapping of SSA values to Graphviz nodes/clusters.
310   DenseMap<Value, Node> valueToNode;
311   /// Counter for generating unique node/subgraph identifiers.
312   int counter = 0;
313 };
314 
315 } // namespace
316 
317 std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
318   return std::make_unique<PrintOpPass>(os);
319 }
320 
321 /// Generate a CFG for a region and show it in a window.
322 static void llvmViewGraph(Region &region, const Twine &name) {
323   int fd;
324   std::string filename = llvm::createGraphFilename(name.str(), fd);
325   {
326     llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
327     if (fd == -1) {
328       llvm::errs() << "error opening file '" << filename << "' for writing\n";
329       return;
330     }
331     PrintOpPass pass(os);
332     pass.emitRegionCFG(region);
333   }
334   llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT);
335 }
336 
337 void mlir::Region::viewGraph(const Twine &regionName) {
338   llvmViewGraph(*this, regionName);
339 }
340 
341 void mlir::Region::viewGraph() { viewGraph("region"); }
342