xref: /llvm-project/mlir/lib/Transforms/ViewOpGraph.cpp (revision b00e0c167186d69e1e6bceda57c09b272bd6acfc)
1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/ViewOpGraph.h"
10 
11 #include "mlir/Analysis/TopologicalSortUtils.h"
12 #include "mlir/IR/Block.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Operation.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Support/IndentedOstream.h"
17 #include "llvm/Support/Format.h"
18 #include "llvm/Support/GraphWriter.h"
19 #include <map>
20 #include <optional>
21 #include <utility>
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_VIEWOPGRAPH
25 #include "mlir/Transforms/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 static const StringRef kLineStyleControlFlow = "dashed";
31 static const StringRef kLineStyleDataFlow = "solid";
32 static const StringRef kShapeNode = "ellipse";
33 static const StringRef kShapeNone = "plain";
34 
35 /// Return the size limits for eliding large attributes.
36 static int64_t getLargeAttributeSizeLimit() {
37   // Use the default from the printer flags if possible.
38   if (std::optional<int64_t> limit =
39           OpPrintingFlags().getLargeElementsAttrLimit())
40     return *limit;
41   return 16;
42 }
43 
44 /// Return all values printed onto a stream as a string.
45 static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
46   std::string buf;
47   llvm::raw_string_ostream os(buf);
48   func(os);
49   return os.str();
50 }
51 
52 /// Escape special characters such as '\n' and quotation marks.
53 static std::string escapeString(std::string str) {
54   return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
55 }
56 
/// Return `str` surrounded by double quotation marks. The contents are not
/// escaped; escape them first (e.g. with `escapeString`) when needed.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted.push_back('"');
  quoted.append(str);
  quoted.push_back('"');
  return quoted;
}
61 
/// DOT attribute list (key -> value). A `std::map` keeps the entries sorted by
/// key, so attribute lists are always emitted in a deterministic order.
using AttributeMap =
    std::map<std::string, std::string, std::less<std::string>>;
63 
64 namespace {
65 
/// A node in the DOT language, identified by a unique number (emitted as
/// `v<id>`). When the node is the anchor node of a cluster (subgraph),
/// `clusterId` additionally holds the number of that cluster.
/// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
/// not between clusters. However, edges can be clipped to the boundary of a
/// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
/// cluster, an invisible "anchor" node is created.
struct Node {
  Node(int nodeId = 0, std::optional<int> anchoredCluster = std::nullopt)
      : id(nodeId), clusterId(anchoredCluster) {}

  /// Numeric node identifier.
  int id;
  /// Number of the cluster this node is the anchor of, if any.
  std::optional<int> clusterId;
};
81 
82 /// This pass generates a Graphviz dataflow visualization of an MLIR operation.
83 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information
84 /// about the Graphviz DOT language.
85 class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
86 public:
87   PrintOpPass(raw_ostream &os) : os(os) {}
88   PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
89 
90   void runOnOperation() override {
91     initColorMapping(*getOperation());
92     emitGraph([&]() {
93       processOperation(getOperation());
94       emitAllEdgeStmts();
95     });
96   }
97 
98   /// Create a CFG graph for a region. Used in `Region::viewGraph`.
99   void emitRegionCFG(Region &region) {
100     printControlFlowEdges = true;
101     printDataFlowEdges = false;
102     initColorMapping(region);
103     emitGraph([&]() { processRegion(region); });
104   }
105 
106 private:
107   /// Generate a color mapping that will color every operation with the same
108   /// name the same way. It'll interpolate the hue in the HSV color-space,
109   /// attempting to keep the contrast suitable for black text.
110   template <typename T>
111   void initColorMapping(T &irEntity) {
112     backgroundColors.clear();
113     SmallVector<Operation *> ops;
114     irEntity.walk([&](Operation *op) {
115       auto &entry = backgroundColors[op->getName()];
116       if (entry.first == 0)
117         ops.push_back(op);
118       ++entry.first;
119     });
120     for (auto indexedOps : llvm::enumerate(ops)) {
121       double hue = ((double)indexedOps.index()) / ops.size();
122       backgroundColors[indexedOps.value()->getName()].second =
123           std::to_string(hue) + " 1.0 1.0";
124     }
125   }
126 
127   /// Emit all edges. This function should be called after all nodes have been
128   /// emitted.
129   void emitAllEdgeStmts() {
130     if (printDataFlowEdges) {
131       for (const auto &[value, node, label] : dataFlowEdges) {
132         emitEdgeStmt(valueToNode[value], node, label, kLineStyleDataFlow);
133       }
134     }
135 
136     for (const std::string &edge : edges)
137       os << edge << ";\n";
138     edges.clear();
139   }
140 
141   /// Emit a cluster (subgraph). The specified builder generates the body of the
142   /// cluster. Return the anchor node of the cluster.
143   Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
144     int clusterId = ++counter;
145     os << "subgraph cluster_" << clusterId << " {\n";
146     os.indent();
147     // Emit invisible anchor node from/to which arrows can be drawn.
148     Node anchorNode = emitNodeStmt(" ", kShapeNone);
149     os << attrStmt("label", quoteString(escapeString(std::move(label))))
150        << ";\n";
151     builder();
152     os.unindent();
153     os << "}\n";
154     return Node(anchorNode.id, clusterId);
155   }
156 
157   /// Generate an attribute statement.
158   std::string attrStmt(const Twine &key, const Twine &value) {
159     return (key + " = " + value).str();
160   }
161 
162   /// Emit an attribute list.
163   void emitAttrList(raw_ostream &os, const AttributeMap &map) {
164     os << "[";
165     interleaveComma(map, os, [&](const auto &it) {
166       os << this->attrStmt(it.first, it.second);
167     });
168     os << "]";
169   }
170 
171   // Print an MLIR attribute to `os`. Large attributes are truncated.
172   void emitMlirAttr(raw_ostream &os, Attribute attr) {
173     // A value used to elide large container attribute.
174     int64_t largeAttrLimit = getLargeAttributeSizeLimit();
175 
176     // Always emit splat attributes.
177     if (isa<SplatElementsAttr>(attr)) {
178       attr.print(os);
179       return;
180     }
181 
182     // Elide "big" elements attributes.
183     auto elements = dyn_cast<ElementsAttr>(attr);
184     if (elements && elements.getNumElements() > largeAttrLimit) {
185       os << std::string(elements.getShapedType().getRank(), '[') << "..."
186          << std::string(elements.getShapedType().getRank(), ']') << " : "
187          << elements.getType();
188       return;
189     }
190 
191     auto array = dyn_cast<ArrayAttr>(attr);
192     if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
193       os << "[...]";
194       return;
195     }
196 
197     // Print all other attributes.
198     std::string buf;
199     llvm::raw_string_ostream ss(buf);
200     attr.print(ss);
201     os << truncateString(ss.str());
202   }
203 
204   /// Append an edge to the list of edges.
205   /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
206   void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) {
207     AttributeMap attrs;
208     attrs["style"] = style.str();
209     // Do not label edges that start/end at a cluster boundary. Such edges are
210     // clipped at the boundary, but labels are not. This can lead to labels
211     // floating around without any edge next to them.
212     if (!n1.clusterId && !n2.clusterId)
213       attrs["label"] = quoteString(escapeString(std::move(label)));
214     // Use `ltail` and `lhead` to draw edges between clusters.
215     if (n1.clusterId)
216       attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
217     if (n2.clusterId)
218       attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
219 
220     edges.push_back(strFromOs([&](raw_ostream &os) {
221       os << llvm::format("v%i -> v%i ", n1.id, n2.id);
222       emitAttrList(os, attrs);
223     }));
224   }
225 
226   /// Emit a graph. The specified builder generates the body of the graph.
227   void emitGraph(function_ref<void()> builder) {
228     os << "digraph G {\n";
229     os.indent();
230     // Edges between clusters are allowed only in compound mode.
231     os << attrStmt("compound", "true") << ";\n";
232     builder();
233     os.unindent();
234     os << "}\n";
235   }
236 
237   /// Emit a node statement.
238   Node emitNodeStmt(std::string label, StringRef shape = kShapeNode,
239                     StringRef background = "") {
240     int nodeId = ++counter;
241     AttributeMap attrs;
242     attrs["label"] = quoteString(escapeString(std::move(label)));
243     attrs["shape"] = shape.str();
244     if (!background.empty()) {
245       attrs["style"] = "filled";
246       attrs["fillcolor"] = ("\"" + background + "\"").str();
247     }
248     os << llvm::format("v%i ", nodeId);
249     emitAttrList(os, attrs);
250     os << ";\n";
251     return Node(nodeId);
252   }
253 
254   /// Generate a label for an operation.
255   std::string getLabel(Operation *op) {
256     return strFromOs([&](raw_ostream &os) {
257       // Print operation name and type.
258       os << op->getName();
259       if (printResultTypes) {
260         os << " : (";
261         std::string buf;
262         llvm::raw_string_ostream ss(buf);
263         interleaveComma(op->getResultTypes(), ss);
264         os << truncateString(ss.str()) << ")";
265       }
266 
267       // Print attributes.
268       if (printAttrs) {
269         os << "\n";
270         for (const NamedAttribute &attr : op->getAttrs()) {
271           os << '\n' << attr.getName().getValue() << ": ";
272           emitMlirAttr(os, attr.getValue());
273         }
274       }
275     });
276   }
277 
278   /// Generate a label for a block argument.
279   std::string getLabel(BlockArgument arg) {
280     return "arg" + std::to_string(arg.getArgNumber());
281   }
282 
283   /// Process a block. Emit a cluster and one node per block argument and
284   /// operation inside the cluster.
285   void processBlock(Block &block) {
286     emitClusterStmt([&]() {
287       for (BlockArgument &blockArg : block.getArguments())
288         valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
289 
290       // Emit a node for each operation.
291       std::optional<Node> prevNode;
292       for (Operation &op : block) {
293         Node nextNode = processOperation(&op);
294         if (printControlFlowEdges && prevNode)
295           emitEdgeStmt(*prevNode, nextNode, /*label=*/"",
296                        kLineStyleControlFlow);
297         prevNode = nextNode;
298       }
299     });
300   }
301 
302   /// Process an operation. If the operation has regions, emit a cluster.
303   /// Otherwise, emit a node.
304   Node processOperation(Operation *op) {
305     Node node;
306     if (op->getNumRegions() > 0) {
307       // Emit cluster for op with regions.
308       node = emitClusterStmt(
309           [&]() {
310             for (Region &region : op->getRegions())
311               processRegion(region);
312           },
313           getLabel(op));
314     } else {
315       node = emitNodeStmt(getLabel(op), kShapeNode,
316                           backgroundColors[op->getName()].second);
317     }
318 
319     // Insert data flow edges originating from each operand.
320     if (printDataFlowEdges) {
321       unsigned numOperands = op->getNumOperands();
322       for (unsigned i = 0; i < numOperands; i++)
323         dataFlowEdges.push_back({op->getOperand(i), node,
324                                  numOperands == 1 ? "" : std::to_string(i)});
325     }
326 
327     for (Value result : op->getResults())
328       valueToNode[result] = node;
329 
330     return node;
331   }
332 
333   /// Process a region.
334   void processRegion(Region &region) {
335     for (Block &block : region.getBlocks())
336       processBlock(block);
337   }
338 
339   /// Truncate long strings.
340   std::string truncateString(std::string str) {
341     if (str.length() <= maxLabelLen)
342       return str;
343     return str.substr(0, maxLabelLen) + "...";
344   }
345 
346   /// Output stream to write DOT file to.
347   raw_indented_ostream os;
348   /// A list of edges. For simplicity, should be emitted after all nodes were
349   /// emitted.
350   std::vector<std::string> edges;
351   /// Mapping of SSA values to Graphviz nodes/clusters.
352   DenseMap<Value, Node> valueToNode;
353   /// Output for data flow edges is delayed until the end to handle cycles
354   std::vector<std::tuple<Value, Node, std::string>> dataFlowEdges;
355   /// Counter for generating unique node/subgraph identifiers.
356   int counter = 0;
357 
358   DenseMap<OperationName, std::pair<int, std::string>> backgroundColors;
359 };
360 
361 } // namespace
362 
363 std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
364   return std::make_unique<PrintOpPass>(os);
365 }
366 
367 /// Generate a CFG for a region and show it in a window.
368 static void llvmViewGraph(Region &region, const Twine &name) {
369   int fd;
370   std::string filename = llvm::createGraphFilename(name.str(), fd);
371   {
372     llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
373     if (fd == -1) {
374       llvm::errs() << "error opening file '" << filename << "' for writing\n";
375       return;
376     }
377     PrintOpPass pass(os);
378     pass.emitRegionCFG(region);
379   }
380   llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT);
381 }
382 
383 void mlir::Region::viewGraph(const Twine &regionName) {
384   llvmViewGraph(*this, regionName);
385 }
386 
387 void mlir::Region::viewGraph() { viewGraph("region"); }
388