1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/ViewOpGraph.h" 10 #include "PassDetail.h" 11 #include "mlir/IR/Block.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/IR/Operation.h" 14 #include "mlir/Support/IndentedOstream.h" 15 #include "llvm/Support/Format.h" 16 #include "llvm/Support/GraphWriter.h" 17 #include <utility> 18 19 using namespace mlir; 20 21 static const StringRef kLineStyleControlFlow = "dashed"; 22 static const StringRef kLineStyleDataFlow = "solid"; 23 static const StringRef kShapeNode = "ellipse"; 24 static const StringRef kShapeNone = "plain"; 25 26 /// Return the size limits for eliding large attributes. 27 static int64_t getLargeAttributeSizeLimit() { 28 // Use the default from the printer flags if possible. 29 if (Optional<int64_t> limit = OpPrintingFlags().getLargeElementsAttrLimit()) 30 return *limit; 31 return 16; 32 } 33 34 /// Return all values printed onto a stream as a string. 35 static std::string strFromOs(function_ref<void(raw_ostream &)> func) { 36 std::string buf; 37 llvm::raw_string_ostream os(buf); 38 func(os); 39 return os.str(); 40 } 41 42 /// Escape special characters such as '\n' and quotation marks. 43 static std::string escapeString(std::string str) { 44 return strFromOs([&](raw_ostream &os) { os.write_escaped(str); }); 45 } 46 47 /// Put quotation marks around a given string. 48 static std::string quoteString(const std::string &str) { 49 return "\"" + str + "\""; 50 } 51 52 using AttributeMap = llvm::StringMap<std::string>; 53 54 namespace { 55 56 /// This struct represents a node in the DOT language. Each node has an 57 /// identifier and an optional identifier for the cluster (subgraph) that 58 /// contains the node. 59 /// Note: In the DOT language, edges can be drawn only from nodes to nodes, but 60 /// not between clusters. However, edges can be clipped to the boundary of a 61 /// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new 62 /// cluster, an invisible "anchor" node is created. 63 struct Node { 64 public: 65 Node(int id = 0, Optional<int> clusterId = llvm::None) 66 : id(id), clusterId(clusterId) {} 67 68 int id; 69 Optional<int> clusterId; 70 }; 71 72 /// This pass generates a Graphviz dataflow visualization of an MLIR operation. 73 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information 74 /// about the Graphviz DOT language. 75 class PrintOpPass : public ViewOpGraphBase<PrintOpPass> { 76 public: 77 PrintOpPass(raw_ostream &os) : os(os) {} 78 PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {} 79 80 void runOnOperation() override { 81 emitGraph([&]() { 82 processOperation(getOperation()); 83 emitAllEdgeStmts(); 84 }); 85 } 86 87 /// Create a CFG graph for a region. Used in `Region::viewGraph`. 88 void emitRegionCFG(Region ®ion) { 89 printControlFlowEdges = true; 90 printDataFlowEdges = false; 91 emitGraph([&]() { processRegion(region); }); 92 } 93 94 private: 95 /// Emit all edges. This function should be called after all nodes have been 96 /// emitted. 97 void emitAllEdgeStmts() { 98 for (const std::string &edge : edges) 99 os << edge << ";\n"; 100 edges.clear(); 101 } 102 103 /// Emit a cluster (subgraph). The specified builder generates the body of the 104 /// cluster. Return the anchor node of the cluster. 105 Node emitClusterStmt(function_ref<void()> builder, std::string label = "") { 106 int clusterId = ++counter; 107 os << "subgraph cluster_" << clusterId << " {\n"; 108 os.indent(); 109 // Emit invisible anchor node from/to which arrows can be drawn. 110 Node anchorNode = emitNodeStmt(" ", kShapeNone); 111 os << attrStmt("label", quoteString(escapeString(std::move(label)))) 112 << ";\n"; 113 builder(); 114 os.unindent(); 115 os << "}\n"; 116 return Node(anchorNode.id, clusterId); 117 } 118 119 /// Generate an attribute statement. 120 std::string attrStmt(const Twine &key, const Twine &value) { 121 return (key + " = " + value).str(); 122 } 123 124 /// Emit an attribute list. 125 void emitAttrList(raw_ostream &os, const AttributeMap &map) { 126 os << "["; 127 interleaveComma(map, os, [&](const auto &it) { 128 os << this->attrStmt(it.getKey(), it.getValue()); 129 }); 130 os << "]"; 131 } 132 133 // Print an MLIR attribute to `os`. Large attributes are truncated. 134 void emitMlirAttr(raw_ostream &os, Attribute attr) { 135 // A value used to elide large container attribute. 136 int64_t largeAttrLimit = getLargeAttributeSizeLimit(); 137 138 // Always emit splat attributes. 139 if (attr.isa<SplatElementsAttr>()) { 140 attr.print(os); 141 return; 142 } 143 144 // Elide "big" elements attributes. 145 auto elements = attr.dyn_cast<ElementsAttr>(); 146 if (elements && elements.getNumElements() > largeAttrLimit) { 147 os << std::string(elements.getType().getRank(), '[') << "..." 148 << std::string(elements.getType().getRank(), ']') << " : " 149 << elements.getType(); 150 return; 151 } 152 153 auto array = attr.dyn_cast<ArrayAttr>(); 154 if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) { 155 os << "[...]"; 156 return; 157 } 158 159 // Print all other attributes. 160 std::string buf; 161 llvm::raw_string_ostream ss(buf); 162 attr.print(ss); 163 os << truncateString(ss.str()); 164 } 165 166 /// Append an edge to the list of edges. 167 /// Note: Edges are written to the output stream via `emitAllEdgeStmts`. 168 void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) { 169 AttributeMap attrs; 170 attrs["style"] = style.str(); 171 // Do not label edges that start/end at a cluster boundary. Such edges are 172 // clipped at the boundary, but labels are not. This can lead to labels 173 // floating around without any edge next to them. 174 if (!n1.clusterId && !n2.clusterId) 175 attrs["label"] = quoteString(escapeString(std::move(label))); 176 // Use `ltail` and `lhead` to draw edges between clusters. 177 if (n1.clusterId) 178 attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId); 179 if (n2.clusterId) 180 attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId); 181 182 edges.push_back(strFromOs([&](raw_ostream &os) { 183 os << llvm::format("v%i -> v%i ", n1.id, n2.id); 184 emitAttrList(os, attrs); 185 })); 186 } 187 188 /// Emit a graph. The specified builder generates the body of the graph. 189 void emitGraph(function_ref<void()> builder) { 190 os << "digraph G {\n"; 191 os.indent(); 192 // Edges between clusters are allowed only in compound mode. 193 os << attrStmt("compound", "true") << ";\n"; 194 builder(); 195 os.unindent(); 196 os << "}\n"; 197 } 198 199 /// Emit a node statement. 200 Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) { 201 int nodeId = ++counter; 202 AttributeMap attrs; 203 attrs["label"] = quoteString(escapeString(std::move(label))); 204 attrs["shape"] = shape.str(); 205 os << llvm::format("v%i ", nodeId); 206 emitAttrList(os, attrs); 207 os << ";\n"; 208 return Node(nodeId); 209 } 210 211 /// Generate a label for an operation. 212 std::string getLabel(Operation *op) { 213 return strFromOs([&](raw_ostream &os) { 214 // Print operation name and type. 215 os << op->getName(); 216 if (printResultTypes) { 217 os << " : ("; 218 std::string buf; 219 llvm::raw_string_ostream ss(buf); 220 interleaveComma(op->getResultTypes(), ss); 221 os << truncateString(ss.str()) << ")"; 222 os << ")"; 223 } 224 225 // Print attributes. 226 if (printAttrs) { 227 os << "\n"; 228 for (const NamedAttribute &attr : op->getAttrs()) { 229 os << '\n' << attr.getName().getValue() << ": "; 230 emitMlirAttr(os, attr.getValue()); 231 } 232 } 233 }); 234 } 235 236 /// Generate a label for a block argument. 237 std::string getLabel(BlockArgument arg) { 238 return "arg" + std::to_string(arg.getArgNumber()); 239 } 240 241 /// Process a block. Emit a cluster and one node per block argument and 242 /// operation inside the cluster. 243 void processBlock(Block &block) { 244 emitClusterStmt([&]() { 245 for (BlockArgument &blockArg : block.getArguments()) 246 valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg)); 247 248 // Emit a node for each operation. 249 Optional<Node> prevNode; 250 for (Operation &op : block) { 251 Node nextNode = processOperation(&op); 252 if (printControlFlowEdges && prevNode) 253 emitEdgeStmt(*prevNode, nextNode, /*label=*/"", 254 kLineStyleControlFlow); 255 prevNode = nextNode; 256 } 257 }); 258 } 259 260 /// Process an operation. If the operation has regions, emit a cluster. 261 /// Otherwise, emit a node. 262 Node processOperation(Operation *op) { 263 Node node; 264 if (op->getNumRegions() > 0) { 265 // Emit cluster for op with regions. 266 node = emitClusterStmt( 267 [&]() { 268 for (Region ®ion : op->getRegions()) 269 processRegion(region); 270 }, 271 getLabel(op)); 272 } else { 273 node = emitNodeStmt(getLabel(op)); 274 } 275 276 // Insert data flow edges originating from each operand. 277 if (printDataFlowEdges) { 278 unsigned numOperands = op->getNumOperands(); 279 for (unsigned i = 0; i < numOperands; i++) 280 emitEdgeStmt(valueToNode[op->getOperand(i)], node, 281 /*label=*/numOperands == 1 ? "" : std::to_string(i), 282 kLineStyleDataFlow); 283 } 284 285 for (Value result : op->getResults()) 286 valueToNode[result] = node; 287 288 return node; 289 } 290 291 /// Process a region. 292 void processRegion(Region ®ion) { 293 for (Block &block : region.getBlocks()) 294 processBlock(block); 295 } 296 297 /// Truncate long strings. 298 std::string truncateString(std::string str) { 299 if (str.length() <= maxLabelLen) 300 return str; 301 return str.substr(0, maxLabelLen) + "..."; 302 } 303 304 /// Output stream to write DOT file to. 305 raw_indented_ostream os; 306 /// A list of edges. For simplicity, should be emitted after all nodes were 307 /// emitted. 308 std::vector<std::string> edges; 309 /// Mapping of SSA values to Graphviz nodes/clusters. 310 DenseMap<Value, Node> valueToNode; 311 /// Counter for generating unique node/subgraph identifiers. 312 int counter = 0; 313 }; 314 315 } // namespace 316 317 std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) { 318 return std::make_unique<PrintOpPass>(os); 319 } 320 321 /// Generate a CFG for a region and show it in a window. 322 static void llvmViewGraph(Region ®ion, const Twine &name) { 323 int fd; 324 std::string filename = llvm::createGraphFilename(name.str(), fd); 325 { 326 llvm::raw_fd_ostream os(fd, /*shouldClose=*/true); 327 if (fd == -1) { 328 llvm::errs() << "error opening file '" << filename << "' for writing\n"; 329 return; 330 } 331 PrintOpPass pass(os); 332 pass.emitRegionCFG(region); 333 } 334 llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT); 335 } 336 337 void mlir::Region::viewGraph(const Twine ®ionName) { 338 llvmViewGraph(*this, regionName); 339 } 340 341 void mlir::Region::viewGraph() { viewGraph("region"); } 342