1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/ViewOpGraph.h" 10 #include "PassDetail.h" 11 #include "mlir/IR/Block.h" 12 #include "mlir/IR/Operation.h" 13 #include "mlir/Support/IndentedOstream.h" 14 #include "llvm/Support/Format.h" 15 #include "llvm/Support/GraphWriter.h" 16 17 using namespace mlir; 18 19 static const StringRef kLineStyleControlFlow = "dashed"; 20 static const StringRef kLineStyleDataFlow = "solid"; 21 static const StringRef kShapeNode = "ellipse"; 22 static const StringRef kShapeNone = "plain"; 23 24 /// Return the size limits for eliding large attributes. 25 static int64_t getLargeAttributeSizeLimit() { 26 // Use the default from the printer flags if possible. 27 if (Optional<int64_t> limit = OpPrintingFlags().getLargeElementsAttrLimit()) 28 return *limit; 29 return 16; 30 } 31 32 /// Return all values printed onto a stream as a string. 33 static std::string strFromOs(function_ref<void(raw_ostream &)> func) { 34 std::string buf; 35 llvm::raw_string_ostream os(buf); 36 func(os); 37 return os.str(); 38 } 39 40 /// Escape special characters such as '\n' and quotation marks. 41 static std::string escapeString(std::string str) { 42 return strFromOs([&](raw_ostream &os) { os.write_escaped(str); }); 43 } 44 45 /// Put quotation marks around a given string. 46 static std::string quoteString(std::string str) { return "\"" + str + "\""; } 47 48 using AttributeMap = llvm::StringMap<std::string>; 49 50 namespace { 51 52 /// This struct represents a node in the DOT language. Each node has an 53 /// identifier and an optional identifier for the cluster (subgraph) that 54 /// contains the node. 55 /// Note: In the DOT language, edges can be drawn only from nodes to nodes, but 56 /// not between clusters. However, edges can be clipped to the boundary of a 57 /// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new 58 /// cluster, an invisible "anchor" node is created. 59 struct Node { 60 public: 61 Node(int id = 0, Optional<int> clusterId = llvm::None) 62 : id(id), clusterId(clusterId) {} 63 64 int id; 65 Optional<int> clusterId; 66 }; 67 68 /// This pass generates a Graphviz dataflow visualization of an MLIR operation. 69 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information 70 /// about the Graphviz DOT language. 71 class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> { 72 public: 73 PrintOpPass(raw_ostream &os) : os(os) {} 74 PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {} 75 76 void runOnOperation() override { 77 emitGraph([&]() { 78 processOperation(getOperation()); 79 emitAllEdgeStmts(); 80 }); 81 } 82 83 /// Create a CFG graph for a region. Used in `Region::viewGraph`. 84 void emitRegionCFG(Region ®ion) { 85 printControlFlowEdges = true; 86 printDataFlowEdges = false; 87 emitGraph([&]() { processRegion(region); }); 88 } 89 90 private: 91 /// Emit all edges. This function should be called after all nodes have been 92 /// emitted. 93 void emitAllEdgeStmts() { 94 for (const std::string &edge : edges) 95 os << edge << ";\n"; 96 edges.clear(); 97 } 98 99 /// Emit a cluster (subgraph). The specified builder generates the body of the 100 /// cluster. Return the anchor node of the cluster. 101 Node emitClusterStmt(function_ref<void()> builder, std::string label = "") { 102 int clusterId = ++counter; 103 os << "subgraph cluster_" << clusterId << " {\n"; 104 os.indent(); 105 // Emit invisible anchor node from/to which arrows can be drawn. 106 Node anchorNode = emitNodeStmt(" ", kShapeNone); 107 os << attrStmt("label", quoteString(escapeString(label))) << ";\n"; 108 builder(); 109 os.unindent(); 110 os << "}\n"; 111 return Node(anchorNode.id, clusterId); 112 } 113 114 /// Generate an attribute statement. 115 std::string attrStmt(const Twine &key, const Twine &value) { 116 return (key + " = " + value).str(); 117 } 118 119 /// Emit an attribute list. 120 void emitAttrList(raw_ostream &os, const AttributeMap &map) { 121 os << "["; 122 interleaveComma(map, os, [&](const auto &it) { 123 os << this->attrStmt(it.getKey(), it.getValue()); 124 }); 125 os << "]"; 126 } 127 128 // Print an MLIR attribute to `os`. Large attributes are truncated. 129 void emitMlirAttr(raw_ostream &os, Attribute attr) { 130 // A value used to elide large container attribute. 131 int64_t largeAttrLimit = getLargeAttributeSizeLimit(); 132 133 // Always emit splat attributes. 134 if (attr.isa<SplatElementsAttr>()) { 135 attr.print(os); 136 return; 137 } 138 139 // Elide "big" elements attributes. 140 auto elements = attr.dyn_cast<ElementsAttr>(); 141 if (elements && elements.getNumElements() > largeAttrLimit) { 142 os << std::string(elements.getType().getRank(), '[') << "..." 143 << std::string(elements.getType().getRank(), ']') << " : " 144 << elements.getType(); 145 return; 146 } 147 148 auto array = attr.dyn_cast<ArrayAttr>(); 149 if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) { 150 os << "[...]"; 151 return; 152 } 153 154 // Print all other attributes. 155 std::string buf; 156 llvm::raw_string_ostream ss(buf); 157 attr.print(ss); 158 os << truncateString(ss.str()); 159 } 160 161 /// Append an edge to the list of edges. 162 /// Note: Edges are written to the output stream via `emitAllEdgeStmts`. 163 void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) { 164 AttributeMap attrs; 165 attrs["style"] = style.str(); 166 // Do not label edges that start/end at a cluster boundary. Such edges are 167 // clipped at the boundary, but labels are not. This can lead to labels 168 // floating around without any edge next to them. 169 if (!n1.clusterId && !n2.clusterId) 170 attrs["label"] = quoteString(escapeString(label)); 171 // Use `ltail` and `lhead` to draw edges between clusters. 172 if (n1.clusterId) 173 attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId); 174 if (n2.clusterId) 175 attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId); 176 177 edges.push_back(strFromOs([&](raw_ostream &os) { 178 os << llvm::format("v%i -> v%i ", n1.id, n2.id); 179 emitAttrList(os, attrs); 180 })); 181 } 182 183 /// Emit a graph. The specified builder generates the body of the graph. 184 void emitGraph(function_ref<void()> builder) { 185 os << "digraph G {\n"; 186 os.indent(); 187 // Edges between clusters are allowed only in compound mode. 188 os << attrStmt("compound", "true") << ";\n"; 189 builder(); 190 os.unindent(); 191 os << "}\n"; 192 } 193 194 /// Emit a node statement. 195 Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) { 196 int nodeId = ++counter; 197 AttributeMap attrs; 198 attrs["label"] = quoteString(escapeString(label)); 199 attrs["shape"] = shape.str(); 200 os << llvm::format("v%i ", nodeId); 201 emitAttrList(os, attrs); 202 os << ";\n"; 203 return Node(nodeId); 204 } 205 206 /// Generate a label for an operation. 207 std::string getLabel(Operation *op) { 208 return strFromOs([&](raw_ostream &os) { 209 // Print operation name and type. 210 os << op->getName(); 211 if (printResultTypes) { 212 os << " : ("; 213 std::string buf; 214 llvm::raw_string_ostream ss(buf); 215 interleaveComma(op->getResultTypes(), ss); 216 os << truncateString(ss.str()) << ")"; 217 os << ")"; 218 } 219 220 // Print attributes. 221 if (printAttrs) { 222 os << "\n"; 223 for (const NamedAttribute &attr : op->getAttrs()) { 224 os << '\n' << attr.getName().getValue() << ": "; 225 emitMlirAttr(os, attr.getValue()); 226 } 227 } 228 }); 229 } 230 231 /// Generate a label for a block argument. 232 std::string getLabel(BlockArgument arg) { 233 return "arg" + std::to_string(arg.getArgNumber()); 234 } 235 236 /// Process a block. Emit a cluster and one node per block argument and 237 /// operation inside the cluster. 238 void processBlock(Block &block) { 239 emitClusterStmt([&]() { 240 for (BlockArgument &blockArg : block.getArguments()) 241 valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg)); 242 243 // Emit a node for each operation. 244 Optional<Node> prevNode; 245 for (Operation &op : block) { 246 Node nextNode = processOperation(&op); 247 if (printControlFlowEdges && prevNode) 248 emitEdgeStmt(*prevNode, nextNode, /*label=*/"", 249 kLineStyleControlFlow); 250 prevNode = nextNode; 251 } 252 }); 253 } 254 255 /// Process an operation. If the operation has regions, emit a cluster. 256 /// Otherwise, emit a node. 257 Node processOperation(Operation *op) { 258 Node node; 259 if (op->getNumRegions() > 0) { 260 // Emit cluster for op with regions. 261 node = emitClusterStmt( 262 [&]() { 263 for (Region ®ion : op->getRegions()) 264 processRegion(region); 265 }, 266 getLabel(op)); 267 } else { 268 node = emitNodeStmt(getLabel(op)); 269 } 270 271 // Insert data flow edges originating from each operand. 272 if (printDataFlowEdges) { 273 unsigned numOperands = op->getNumOperands(); 274 for (unsigned i = 0; i < numOperands; i++) 275 emitEdgeStmt(valueToNode[op->getOperand(i)], node, 276 /*label=*/numOperands == 1 ? "" : std::to_string(i), 277 kLineStyleDataFlow); 278 } 279 280 for (Value result : op->getResults()) 281 valueToNode[result] = node; 282 283 return node; 284 } 285 286 /// Process a region. 287 void processRegion(Region ®ion) { 288 for (Block &block : region.getBlocks()) 289 processBlock(block); 290 } 291 292 /// Truncate long strings. 293 std::string truncateString(std::string str) { 294 if (str.length() <= maxLabelLen) 295 return str; 296 return str.substr(0, maxLabelLen) + "..."; 297 } 298 299 /// Output stream to write DOT file to. 300 raw_indented_ostream os; 301 /// A list of edges. For simplicity, should be emitted after all nodes were 302 /// emitted. 303 std::vector<std::string> edges; 304 /// Mapping of SSA values to Graphviz nodes/clusters. 305 DenseMap<Value, Node> valueToNode; 306 /// Counter for generating unique node/subgraph identifiers. 307 int counter = 0; 308 }; 309 310 } // namespace 311 312 std::unique_ptr<Pass> 313 mlir::createPrintOpGraphPass(raw_ostream &os) { 314 return std::make_unique<PrintOpPass>(os); 315 } 316 317 /// Generate a CFG for a region and show it in a window. 318 static void llvmViewGraph(Region ®ion, const Twine &name) { 319 int fd; 320 std::string filename = llvm::createGraphFilename(name.str(), fd); 321 { 322 llvm::raw_fd_ostream os(fd, /*shouldClose=*/true); 323 if (fd == -1) { 324 llvm::errs() << "error opening file '" << filename << "' for writing\n"; 325 return; 326 } 327 PrintOpPass pass(os); 328 pass.emitRegionCFG(region); 329 } 330 llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT); 331 } 332 333 void mlir::Region::viewGraph(const Twine ®ionName) { 334 llvmViewGraph(*this, regionName); 335 } 336 337 void mlir::Region::viewGraph() { viewGraph("region"); } 338