xref: /llvm-project/mlir/lib/Transforms/ViewOpGraph.cpp (revision 618caf6fcce17d6f1ca752604138fb234e781ff3)
1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/ViewOpGraph.h"
10 
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Support/IndentedOstream.h"
16 #include "llvm/Support/Format.h"
17 #include "llvm/Support/GraphWriter.h"
18 #include <map>
19 #include <optional>
20 #include <utility>
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_VIEWOPGRAPH
24 #include "mlir/Transforms/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 static const StringRef kLineStyleControlFlow = "dashed";
30 static const StringRef kLineStyleDataFlow = "solid";
31 static const StringRef kShapeNode = "ellipse";
32 static const StringRef kShapeNone = "plain";
33 
34 /// Return the size limits for eliding large attributes.
35 static int64_t getLargeAttributeSizeLimit() {
36   // Use the default from the printer flags if possible.
37   if (std::optional<int64_t> limit =
38           OpPrintingFlags().getLargeElementsAttrLimit())
39     return *limit;
40   return 16;
41 }
42 
43 /// Return all values printed onto a stream as a string.
44 static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
45   std::string buf;
46   llvm::raw_string_ostream os(buf);
47   func(os);
48   return os.str();
49 }
50 
51 /// Escape special characters such as '\n' and quotation marks.
52 static std::string escapeString(std::string str) {
53   return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
54 }
55 
/// Wrap `str` in double quotation marks. The content itself is not escaped;
/// callers in this file pass it through `escapeString` first.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted += '"';
  quoted += str;
  quoted += '"';
  return quoted;
}

/// Key/value pairs of a DOT attribute list. `std::map` keeps the attributes
/// ordered by key when they are printed.
using AttributeMap = std::map<std::string, std::string>;
62 
63 namespace {
64 
/// A node in the DOT language, printed as `v<id>`. When `clusterId` is set,
/// the node is the anchor of the cluster (subgraph) `cluster_<clusterId>` that
/// contains it, and edges attached to it are clipped to that cluster.
/// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
/// not between clusters. However, edges can be clipped to the boundary of a
/// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
/// cluster, an invisible "anchor" node is created.
struct Node {
  Node(int id = 0, std::optional<int> clusterId = std::nullopt)
      : id{id}, clusterId{clusterId} {}

  /// Numeric node identifier.
  int id;
  /// Identifier of the cluster anchored by this node, if any.
  std::optional<int> clusterId;
};
80 
81 /// This pass generates a Graphviz dataflow visualization of an MLIR operation.
82 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information
83 /// about the Graphviz DOT language.
84 class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
85 public:
86   PrintOpPass(raw_ostream &os) : os(os) {}
87   PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
88 
89   void runOnOperation() override {
90     emitGraph([&]() {
91       processOperation(getOperation());
92       emitAllEdgeStmts();
93     });
94   }
95 
96   /// Create a CFG graph for a region. Used in `Region::viewGraph`.
97   void emitRegionCFG(Region &region) {
98     printControlFlowEdges = true;
99     printDataFlowEdges = false;
100     emitGraph([&]() { processRegion(region); });
101   }
102 
103 private:
104   /// Emit all edges. This function should be called after all nodes have been
105   /// emitted.
106   void emitAllEdgeStmts() {
107     for (const std::string &edge : edges)
108       os << edge << ";\n";
109     edges.clear();
110   }
111 
112   /// Emit a cluster (subgraph). The specified builder generates the body of the
113   /// cluster. Return the anchor node of the cluster.
114   Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
115     int clusterId = ++counter;
116     os << "subgraph cluster_" << clusterId << " {\n";
117     os.indent();
118     // Emit invisible anchor node from/to which arrows can be drawn.
119     Node anchorNode = emitNodeStmt(" ", kShapeNone);
120     os << attrStmt("label", quoteString(escapeString(std::move(label))))
121        << ";\n";
122     builder();
123     os.unindent();
124     os << "}\n";
125     return Node(anchorNode.id, clusterId);
126   }
127 
128   /// Generate an attribute statement.
129   std::string attrStmt(const Twine &key, const Twine &value) {
130     return (key + " = " + value).str();
131   }
132 
133   /// Emit an attribute list.
134   void emitAttrList(raw_ostream &os, const AttributeMap &map) {
135     os << "[";
136     interleaveComma(map, os, [&](const auto &it) {
137       os << this->attrStmt(it.first, it.second);
138     });
139     os << "]";
140   }
141 
142   // Print an MLIR attribute to `os`. Large attributes are truncated.
143   void emitMlirAttr(raw_ostream &os, Attribute attr) {
144     // A value used to elide large container attribute.
145     int64_t largeAttrLimit = getLargeAttributeSizeLimit();
146 
147     // Always emit splat attributes.
148     if (attr.isa<SplatElementsAttr>()) {
149       attr.print(os);
150       return;
151     }
152 
153     // Elide "big" elements attributes.
154     auto elements = attr.dyn_cast<ElementsAttr>();
155     if (elements && elements.getNumElements() > largeAttrLimit) {
156       os << std::string(elements.getType().getRank(), '[') << "..."
157          << std::string(elements.getType().getRank(), ']') << " : "
158          << elements.getType();
159       return;
160     }
161 
162     auto array = attr.dyn_cast<ArrayAttr>();
163     if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
164       os << "[...]";
165       return;
166     }
167 
168     // Print all other attributes.
169     std::string buf;
170     llvm::raw_string_ostream ss(buf);
171     attr.print(ss);
172     os << truncateString(ss.str());
173   }
174 
175   /// Append an edge to the list of edges.
176   /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
177   void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) {
178     AttributeMap attrs;
179     attrs["style"] = style.str();
180     // Do not label edges that start/end at a cluster boundary. Such edges are
181     // clipped at the boundary, but labels are not. This can lead to labels
182     // floating around without any edge next to them.
183     if (!n1.clusterId && !n2.clusterId)
184       attrs["label"] = quoteString(escapeString(std::move(label)));
185     // Use `ltail` and `lhead` to draw edges between clusters.
186     if (n1.clusterId)
187       attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
188     if (n2.clusterId)
189       attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
190 
191     edges.push_back(strFromOs([&](raw_ostream &os) {
192       os << llvm::format("v%i -> v%i ", n1.id, n2.id);
193       emitAttrList(os, attrs);
194     }));
195   }
196 
197   /// Emit a graph. The specified builder generates the body of the graph.
198   void emitGraph(function_ref<void()> builder) {
199     os << "digraph G {\n";
200     os.indent();
201     // Edges between clusters are allowed only in compound mode.
202     os << attrStmt("compound", "true") << ";\n";
203     builder();
204     os.unindent();
205     os << "}\n";
206   }
207 
208   /// Emit a node statement.
209   Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) {
210     int nodeId = ++counter;
211     AttributeMap attrs;
212     attrs["label"] = quoteString(escapeString(std::move(label)));
213     attrs["shape"] = shape.str();
214     os << llvm::format("v%i ", nodeId);
215     emitAttrList(os, attrs);
216     os << ";\n";
217     return Node(nodeId);
218   }
219 
220   /// Generate a label for an operation.
221   std::string getLabel(Operation *op) {
222     return strFromOs([&](raw_ostream &os) {
223       // Print operation name and type.
224       os << op->getName();
225       if (printResultTypes) {
226         os << " : (";
227         std::string buf;
228         llvm::raw_string_ostream ss(buf);
229         interleaveComma(op->getResultTypes(), ss);
230         os << truncateString(ss.str()) << ")";
231         os << ")";
232       }
233 
234       // Print attributes.
235       if (printAttrs) {
236         os << "\n";
237         for (const NamedAttribute &attr : op->getAttrs()) {
238           os << '\n' << attr.getName().getValue() << ": ";
239           emitMlirAttr(os, attr.getValue());
240         }
241       }
242     });
243   }
244 
245   /// Generate a label for a block argument.
246   std::string getLabel(BlockArgument arg) {
247     return "arg" + std::to_string(arg.getArgNumber());
248   }
249 
250   /// Process a block. Emit a cluster and one node per block argument and
251   /// operation inside the cluster.
252   void processBlock(Block &block) {
253     emitClusterStmt([&]() {
254       for (BlockArgument &blockArg : block.getArguments())
255         valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
256 
257       // Emit a node for each operation.
258       std::optional<Node> prevNode;
259       for (Operation &op : block) {
260         Node nextNode = processOperation(&op);
261         if (printControlFlowEdges && prevNode)
262           emitEdgeStmt(*prevNode, nextNode, /*label=*/"",
263                        kLineStyleControlFlow);
264         prevNode = nextNode;
265       }
266     });
267   }
268 
269   /// Process an operation. If the operation has regions, emit a cluster.
270   /// Otherwise, emit a node.
271   Node processOperation(Operation *op) {
272     Node node;
273     if (op->getNumRegions() > 0) {
274       // Emit cluster for op with regions.
275       node = emitClusterStmt(
276           [&]() {
277             for (Region &region : op->getRegions())
278               processRegion(region);
279           },
280           getLabel(op));
281     } else {
282       node = emitNodeStmt(getLabel(op));
283     }
284 
285     // Insert data flow edges originating from each operand.
286     if (printDataFlowEdges) {
287       unsigned numOperands = op->getNumOperands();
288       for (unsigned i = 0; i < numOperands; i++)
289         emitEdgeStmt(valueToNode[op->getOperand(i)], node,
290                      /*label=*/numOperands == 1 ? "" : std::to_string(i),
291                      kLineStyleDataFlow);
292     }
293 
294     for (Value result : op->getResults())
295       valueToNode[result] = node;
296 
297     return node;
298   }
299 
300   /// Process a region.
301   void processRegion(Region &region) {
302     for (Block &block : region.getBlocks())
303       processBlock(block);
304   }
305 
306   /// Truncate long strings.
307   std::string truncateString(std::string str) {
308     if (str.length() <= maxLabelLen)
309       return str;
310     return str.substr(0, maxLabelLen) + "...";
311   }
312 
313   /// Output stream to write DOT file to.
314   raw_indented_ostream os;
315   /// A list of edges. For simplicity, should be emitted after all nodes were
316   /// emitted.
317   std::vector<std::string> edges;
318   /// Mapping of SSA values to Graphviz nodes/clusters.
319   DenseMap<Value, Node> valueToNode;
320   /// Counter for generating unique node/subgraph identifiers.
321   int counter = 0;
322 };
323 
324 } // namespace
325 
326 std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
327   return std::make_unique<PrintOpPass>(os);
328 }
329 
330 /// Generate a CFG for a region and show it in a window.
331 static void llvmViewGraph(Region &region, const Twine &name) {
332   int fd;
333   std::string filename = llvm::createGraphFilename(name.str(), fd);
334   {
335     llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
336     if (fd == -1) {
337       llvm::errs() << "error opening file '" << filename << "' for writing\n";
338       return;
339     }
340     PrintOpPass pass(os);
341     pass.emitRegionCFG(region);
342   }
343   llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT);
344 }
345 
346 void mlir::Region::viewGraph(const Twine &regionName) {
347   llvmViewGraph(*this, regionName);
348 }
349 
350 void mlir::Region::viewGraph() { viewGraph("region"); }
351