1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/ViewOpGraph.h" 10 11 #include "mlir/IR/Block.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/IR/Operation.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Support/IndentedOstream.h" 16 #include "llvm/Support/Format.h" 17 #include "llvm/Support/GraphWriter.h" 18 #include <map> 19 #include <optional> 20 #include <utility> 21 22 namespace mlir { 23 #define GEN_PASS_DEF_VIEWOPGRAPH 24 #include "mlir/Transforms/Passes.h.inc" 25 } // namespace mlir 26 27 using namespace mlir; 28 29 static const StringRef kLineStyleControlFlow = "dashed"; 30 static const StringRef kLineStyleDataFlow = "solid"; 31 static const StringRef kShapeNode = "ellipse"; 32 static const StringRef kShapeNone = "plain"; 33 34 /// Return the size limits for eliding large attributes. 35 static int64_t getLargeAttributeSizeLimit() { 36 // Use the default from the printer flags if possible. 37 if (std::optional<int64_t> limit = 38 OpPrintingFlags().getLargeElementsAttrLimit()) 39 return *limit; 40 return 16; 41 } 42 43 /// Return all values printed onto a stream as a string. 44 static std::string strFromOs(function_ref<void(raw_ostream &)> func) { 45 std::string buf; 46 llvm::raw_string_ostream os(buf); 47 func(os); 48 return os.str(); 49 } 50 51 /// Escape special characters such as '\n' and quotation marks. 52 static std::string escapeString(std::string str) { 53 return strFromOs([&](raw_ostream &os) { os.write_escaped(str); }); 54 } 55 56 /// Put quotation marks around a given string. 57 static std::string quoteString(const std::string &str) { 58 return "\"" + str + "\""; 59 } 60 61 using AttributeMap = std::map<std::string, std::string>; 62 63 namespace { 64 65 /// This struct represents a node in the DOT language. Each node has an 66 /// identifier and an optional identifier for the cluster (subgraph) that 67 /// contains the node. 68 /// Note: In the DOT language, edges can be drawn only from nodes to nodes, but 69 /// not between clusters. However, edges can be clipped to the boundary of a 70 /// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new 71 /// cluster, an invisible "anchor" node is created. 72 struct Node { 73 public: 74 Node(int id = 0, std::optional<int> clusterId = std::nullopt) 75 : id(id), clusterId(clusterId) {} 76 77 int id; 78 std::optional<int> clusterId; 79 }; 80 81 /// This pass generates a Graphviz dataflow visualization of an MLIR operation. 82 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information 83 /// about the Graphviz DOT language. 84 class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> { 85 public: 86 PrintOpPass(raw_ostream &os) : os(os) {} 87 PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {} 88 89 void runOnOperation() override { 90 emitGraph([&]() { 91 processOperation(getOperation()); 92 emitAllEdgeStmts(); 93 }); 94 } 95 96 /// Create a CFG graph for a region. Used in `Region::viewGraph`. 97 void emitRegionCFG(Region ®ion) { 98 printControlFlowEdges = true; 99 printDataFlowEdges = false; 100 emitGraph([&]() { processRegion(region); }); 101 } 102 103 private: 104 /// Emit all edges. This function should be called after all nodes have been 105 /// emitted. 106 void emitAllEdgeStmts() { 107 for (const std::string &edge : edges) 108 os << edge << ";\n"; 109 edges.clear(); 110 } 111 112 /// Emit a cluster (subgraph). The specified builder generates the body of the 113 /// cluster. Return the anchor node of the cluster. 114 Node emitClusterStmt(function_ref<void()> builder, std::string label = "") { 115 int clusterId = ++counter; 116 os << "subgraph cluster_" << clusterId << " {\n"; 117 os.indent(); 118 // Emit invisible anchor node from/to which arrows can be drawn. 119 Node anchorNode = emitNodeStmt(" ", kShapeNone); 120 os << attrStmt("label", quoteString(escapeString(std::move(label)))) 121 << ";\n"; 122 builder(); 123 os.unindent(); 124 os << "}\n"; 125 return Node(anchorNode.id, clusterId); 126 } 127 128 /// Generate an attribute statement. 129 std::string attrStmt(const Twine &key, const Twine &value) { 130 return (key + " = " + value).str(); 131 } 132 133 /// Emit an attribute list. 134 void emitAttrList(raw_ostream &os, const AttributeMap &map) { 135 os << "["; 136 interleaveComma(map, os, [&](const auto &it) { 137 os << this->attrStmt(it.first, it.second); 138 }); 139 os << "]"; 140 } 141 142 // Print an MLIR attribute to `os`. Large attributes are truncated. 143 void emitMlirAttr(raw_ostream &os, Attribute attr) { 144 // A value used to elide large container attribute. 145 int64_t largeAttrLimit = getLargeAttributeSizeLimit(); 146 147 // Always emit splat attributes. 148 if (attr.isa<SplatElementsAttr>()) { 149 attr.print(os); 150 return; 151 } 152 153 // Elide "big" elements attributes. 154 auto elements = attr.dyn_cast<ElementsAttr>(); 155 if (elements && elements.getNumElements() > largeAttrLimit) { 156 os << std::string(elements.getType().getRank(), '[') << "..." 157 << std::string(elements.getType().getRank(), ']') << " : " 158 << elements.getType(); 159 return; 160 } 161 162 auto array = attr.dyn_cast<ArrayAttr>(); 163 if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) { 164 os << "[...]"; 165 return; 166 } 167 168 // Print all other attributes. 169 std::string buf; 170 llvm::raw_string_ostream ss(buf); 171 attr.print(ss); 172 os << truncateString(ss.str()); 173 } 174 175 /// Append an edge to the list of edges. 176 /// Note: Edges are written to the output stream via `emitAllEdgeStmts`. 177 void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) { 178 AttributeMap attrs; 179 attrs["style"] = style.str(); 180 // Do not label edges that start/end at a cluster boundary. Such edges are 181 // clipped at the boundary, but labels are not. This can lead to labels 182 // floating around without any edge next to them. 183 if (!n1.clusterId && !n2.clusterId) 184 attrs["label"] = quoteString(escapeString(std::move(label))); 185 // Use `ltail` and `lhead` to draw edges between clusters. 186 if (n1.clusterId) 187 attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId); 188 if (n2.clusterId) 189 attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId); 190 191 edges.push_back(strFromOs([&](raw_ostream &os) { 192 os << llvm::format("v%i -> v%i ", n1.id, n2.id); 193 emitAttrList(os, attrs); 194 })); 195 } 196 197 /// Emit a graph. The specified builder generates the body of the graph. 198 void emitGraph(function_ref<void()> builder) { 199 os << "digraph G {\n"; 200 os.indent(); 201 // Edges between clusters are allowed only in compound mode. 202 os << attrStmt("compound", "true") << ";\n"; 203 builder(); 204 os.unindent(); 205 os << "}\n"; 206 } 207 208 /// Emit a node statement. 209 Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) { 210 int nodeId = ++counter; 211 AttributeMap attrs; 212 attrs["label"] = quoteString(escapeString(std::move(label))); 213 attrs["shape"] = shape.str(); 214 os << llvm::format("v%i ", nodeId); 215 emitAttrList(os, attrs); 216 os << ";\n"; 217 return Node(nodeId); 218 } 219 220 /// Generate a label for an operation. 221 std::string getLabel(Operation *op) { 222 return strFromOs([&](raw_ostream &os) { 223 // Print operation name and type. 224 os << op->getName(); 225 if (printResultTypes) { 226 os << " : ("; 227 std::string buf; 228 llvm::raw_string_ostream ss(buf); 229 interleaveComma(op->getResultTypes(), ss); 230 os << truncateString(ss.str()) << ")"; 231 os << ")"; 232 } 233 234 // Print attributes. 235 if (printAttrs) { 236 os << "\n"; 237 for (const NamedAttribute &attr : op->getAttrs()) { 238 os << '\n' << attr.getName().getValue() << ": "; 239 emitMlirAttr(os, attr.getValue()); 240 } 241 } 242 }); 243 } 244 245 /// Generate a label for a block argument. 246 std::string getLabel(BlockArgument arg) { 247 return "arg" + std::to_string(arg.getArgNumber()); 248 } 249 250 /// Process a block. Emit a cluster and one node per block argument and 251 /// operation inside the cluster. 252 void processBlock(Block &block) { 253 emitClusterStmt([&]() { 254 for (BlockArgument &blockArg : block.getArguments()) 255 valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg)); 256 257 // Emit a node for each operation. 258 std::optional<Node> prevNode; 259 for (Operation &op : block) { 260 Node nextNode = processOperation(&op); 261 if (printControlFlowEdges && prevNode) 262 emitEdgeStmt(*prevNode, nextNode, /*label=*/"", 263 kLineStyleControlFlow); 264 prevNode = nextNode; 265 } 266 }); 267 } 268 269 /// Process an operation. If the operation has regions, emit a cluster. 270 /// Otherwise, emit a node. 271 Node processOperation(Operation *op) { 272 Node node; 273 if (op->getNumRegions() > 0) { 274 // Emit cluster for op with regions. 275 node = emitClusterStmt( 276 [&]() { 277 for (Region ®ion : op->getRegions()) 278 processRegion(region); 279 }, 280 getLabel(op)); 281 } else { 282 node = emitNodeStmt(getLabel(op)); 283 } 284 285 // Insert data flow edges originating from each operand. 286 if (printDataFlowEdges) { 287 unsigned numOperands = op->getNumOperands(); 288 for (unsigned i = 0; i < numOperands; i++) 289 emitEdgeStmt(valueToNode[op->getOperand(i)], node, 290 /*label=*/numOperands == 1 ? "" : std::to_string(i), 291 kLineStyleDataFlow); 292 } 293 294 for (Value result : op->getResults()) 295 valueToNode[result] = node; 296 297 return node; 298 } 299 300 /// Process a region. 301 void processRegion(Region ®ion) { 302 for (Block &block : region.getBlocks()) 303 processBlock(block); 304 } 305 306 /// Truncate long strings. 307 std::string truncateString(std::string str) { 308 if (str.length() <= maxLabelLen) 309 return str; 310 return str.substr(0, maxLabelLen) + "..."; 311 } 312 313 /// Output stream to write DOT file to. 314 raw_indented_ostream os; 315 /// A list of edges. For simplicity, should be emitted after all nodes were 316 /// emitted. 317 std::vector<std::string> edges; 318 /// Mapping of SSA values to Graphviz nodes/clusters. 319 DenseMap<Value, Node> valueToNode; 320 /// Counter for generating unique node/subgraph identifiers. 321 int counter = 0; 322 }; 323 324 } // namespace 325 326 std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) { 327 return std::make_unique<PrintOpPass>(os); 328 } 329 330 /// Generate a CFG for a region and show it in a window. 331 static void llvmViewGraph(Region ®ion, const Twine &name) { 332 int fd; 333 std::string filename = llvm::createGraphFilename(name.str(), fd); 334 { 335 llvm::raw_fd_ostream os(fd, /*shouldClose=*/true); 336 if (fd == -1) { 337 llvm::errs() << "error opening file '" << filename << "' for writing\n"; 338 return; 339 } 340 PrintOpPass pass(os); 341 pass.emitRegionCFG(region); 342 } 343 llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT); 344 } 345 346 void mlir::Region::viewGraph(const Twine ®ionName) { 347 llvmViewGraph(*this, regionName); 348 } 349 350 void mlir::Region::viewGraph() { viewGraph("region"); } 351