1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/ViewOpGraph.h" 10 11 #include "mlir/Analysis/TopologicalSortUtils.h" 12 #include "mlir/IR/Block.h" 13 #include "mlir/IR/BuiltinTypes.h" 14 #include "mlir/IR/Operation.h" 15 #include "mlir/Pass/Pass.h" 16 #include "mlir/Support/IndentedOstream.h" 17 #include "llvm/Support/Format.h" 18 #include "llvm/Support/GraphWriter.h" 19 #include <map> 20 #include <optional> 21 #include <utility> 22 23 namespace mlir { 24 #define GEN_PASS_DEF_VIEWOPGRAPH 25 #include "mlir/Transforms/Passes.h.inc" 26 } // namespace mlir 27 28 using namespace mlir; 29 30 static const StringRef kLineStyleControlFlow = "dashed"; 31 static const StringRef kLineStyleDataFlow = "solid"; 32 static const StringRef kShapeNode = "ellipse"; 33 static const StringRef kShapeNone = "plain"; 34 35 /// Return the size limits for eliding large attributes. 36 static int64_t getLargeAttributeSizeLimit() { 37 // Use the default from the printer flags if possible. 38 if (std::optional<int64_t> limit = 39 OpPrintingFlags().getLargeElementsAttrLimit()) 40 return *limit; 41 return 16; 42 } 43 44 /// Return all values printed onto a stream as a string. 45 static std::string strFromOs(function_ref<void(raw_ostream &)> func) { 46 std::string buf; 47 llvm::raw_string_ostream os(buf); 48 func(os); 49 return buf; 50 } 51 52 /// Escape special characters such as '\n' and quotation marks. 53 static std::string escapeString(std::string str) { 54 return strFromOs([&](raw_ostream &os) { os.write_escaped(str); }); 55 } 56 57 /// Put quotation marks around a given string. 58 static std::string quoteString(const std::string &str) { 59 return "\"" + str + "\""; 60 } 61 62 using AttributeMap = std::map<std::string, std::string>; 63 64 namespace { 65 66 /// This struct represents a node in the DOT language. Each node has an 67 /// identifier and an optional identifier for the cluster (subgraph) that 68 /// contains the node. 69 /// Note: In the DOT language, edges can be drawn only from nodes to nodes, but 70 /// not between clusters. However, edges can be clipped to the boundary of a 71 /// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new 72 /// cluster, an invisible "anchor" node is created. 73 struct Node { 74 public: 75 Node(int id = 0, std::optional<int> clusterId = std::nullopt) 76 : id(id), clusterId(clusterId) {} 77 78 int id; 79 std::optional<int> clusterId; 80 }; 81 82 /// This pass generates a Graphviz dataflow visualization of an MLIR operation. 83 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information 84 /// about the Graphviz DOT language. 85 class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> { 86 public: 87 PrintOpPass(raw_ostream &os) : os(os) {} 88 PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {} 89 90 void runOnOperation() override { 91 initColorMapping(*getOperation()); 92 emitGraph([&]() { 93 processOperation(getOperation()); 94 emitAllEdgeStmts(); 95 }); 96 markAllAnalysesPreserved(); 97 } 98 99 /// Create a CFG graph for a region. Used in `Region::viewGraph`. 100 void emitRegionCFG(Region ®ion) { 101 printControlFlowEdges = true; 102 printDataFlowEdges = false; 103 initColorMapping(region); 104 emitGraph([&]() { processRegion(region); }); 105 } 106 107 private: 108 /// Generate a color mapping that will color every operation with the same 109 /// name the same way. It'll interpolate the hue in the HSV color-space, 110 /// attempting to keep the contrast suitable for black text. 111 template <typename T> 112 void initColorMapping(T &irEntity) { 113 backgroundColors.clear(); 114 SmallVector<Operation *> ops; 115 irEntity.walk([&](Operation *op) { 116 auto &entry = backgroundColors[op->getName()]; 117 if (entry.first == 0) 118 ops.push_back(op); 119 ++entry.first; 120 }); 121 for (auto indexedOps : llvm::enumerate(ops)) { 122 double hue = ((double)indexedOps.index()) / ops.size(); 123 backgroundColors[indexedOps.value()->getName()].second = 124 std::to_string(hue) + " 1.0 1.0"; 125 } 126 } 127 128 /// Emit all edges. This function should be called after all nodes have been 129 /// emitted. 130 void emitAllEdgeStmts() { 131 if (printDataFlowEdges) { 132 for (const auto &[value, node, label] : dataFlowEdges) { 133 emitEdgeStmt(valueToNode[value], node, label, kLineStyleDataFlow); 134 } 135 } 136 137 for (const std::string &edge : edges) 138 os << edge << ";\n"; 139 edges.clear(); 140 } 141 142 /// Emit a cluster (subgraph). The specified builder generates the body of the 143 /// cluster. Return the anchor node of the cluster. 144 Node emitClusterStmt(function_ref<void()> builder, std::string label = "") { 145 int clusterId = ++counter; 146 os << "subgraph cluster_" << clusterId << " {\n"; 147 os.indent(); 148 // Emit invisible anchor node from/to which arrows can be drawn. 149 Node anchorNode = emitNodeStmt(" ", kShapeNone); 150 os << attrStmt("label", quoteString(escapeString(std::move(label)))) 151 << ";\n"; 152 builder(); 153 os.unindent(); 154 os << "}\n"; 155 return Node(anchorNode.id, clusterId); 156 } 157 158 /// Generate an attribute statement. 159 std::string attrStmt(const Twine &key, const Twine &value) { 160 return (key + " = " + value).str(); 161 } 162 163 /// Emit an attribute list. 164 void emitAttrList(raw_ostream &os, const AttributeMap &map) { 165 os << "["; 166 interleaveComma(map, os, [&](const auto &it) { 167 os << this->attrStmt(it.first, it.second); 168 }); 169 os << "]"; 170 } 171 172 // Print an MLIR attribute to `os`. Large attributes are truncated. 173 void emitMlirAttr(raw_ostream &os, Attribute attr) { 174 // A value used to elide large container attribute. 175 int64_t largeAttrLimit = getLargeAttributeSizeLimit(); 176 177 // Always emit splat attributes. 178 if (isa<SplatElementsAttr>(attr)) { 179 attr.print(os); 180 return; 181 } 182 183 // Elide "big" elements attributes. 184 auto elements = dyn_cast<ElementsAttr>(attr); 185 if (elements && elements.getNumElements() > largeAttrLimit) { 186 os << std::string(elements.getShapedType().getRank(), '[') << "..." 187 << std::string(elements.getShapedType().getRank(), ']') << " : " 188 << elements.getType(); 189 return; 190 } 191 192 auto array = dyn_cast<ArrayAttr>(attr); 193 if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) { 194 os << "[...]"; 195 return; 196 } 197 198 // Print all other attributes. 199 std::string buf; 200 llvm::raw_string_ostream ss(buf); 201 attr.print(ss); 202 os << truncateString(buf); 203 } 204 205 /// Append an edge to the list of edges. 206 /// Note: Edges are written to the output stream via `emitAllEdgeStmts`. 207 void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) { 208 AttributeMap attrs; 209 attrs["style"] = style.str(); 210 // Do not label edges that start/end at a cluster boundary. Such edges are 211 // clipped at the boundary, but labels are not. This can lead to labels 212 // floating around without any edge next to them. 213 if (!n1.clusterId && !n2.clusterId) 214 attrs["label"] = quoteString(escapeString(std::move(label))); 215 // Use `ltail` and `lhead` to draw edges between clusters. 216 if (n1.clusterId) 217 attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId); 218 if (n2.clusterId) 219 attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId); 220 221 edges.push_back(strFromOs([&](raw_ostream &os) { 222 os << llvm::format("v%i -> v%i ", n1.id, n2.id); 223 emitAttrList(os, attrs); 224 })); 225 } 226 227 /// Emit a graph. The specified builder generates the body of the graph. 228 void emitGraph(function_ref<void()> builder) { 229 os << "digraph G {\n"; 230 os.indent(); 231 // Edges between clusters are allowed only in compound mode. 232 os << attrStmt("compound", "true") << ";\n"; 233 builder(); 234 os.unindent(); 235 os << "}\n"; 236 } 237 238 /// Emit a node statement. 239 Node emitNodeStmt(std::string label, StringRef shape = kShapeNode, 240 StringRef background = "") { 241 int nodeId = ++counter; 242 AttributeMap attrs; 243 attrs["label"] = quoteString(escapeString(std::move(label))); 244 attrs["shape"] = shape.str(); 245 if (!background.empty()) { 246 attrs["style"] = "filled"; 247 attrs["fillcolor"] = ("\"" + background + "\"").str(); 248 } 249 os << llvm::format("v%i ", nodeId); 250 emitAttrList(os, attrs); 251 os << ";\n"; 252 return Node(nodeId); 253 } 254 255 /// Generate a label for an operation. 256 std::string getLabel(Operation *op) { 257 return strFromOs([&](raw_ostream &os) { 258 // Print operation name and type. 259 os << op->getName(); 260 if (printResultTypes) { 261 os << " : ("; 262 std::string buf; 263 llvm::raw_string_ostream ss(buf); 264 interleaveComma(op->getResultTypes(), ss); 265 os << truncateString(buf) << ")"; 266 } 267 268 // Print attributes. 269 if (printAttrs) { 270 os << "\n"; 271 for (const NamedAttribute &attr : op->getAttrs()) { 272 os << '\n' << attr.getName().getValue() << ": "; 273 emitMlirAttr(os, attr.getValue()); 274 } 275 } 276 }); 277 } 278 279 /// Generate a label for a block argument. 280 std::string getLabel(BlockArgument arg) { 281 return "arg" + std::to_string(arg.getArgNumber()); 282 } 283 284 /// Process a block. Emit a cluster and one node per block argument and 285 /// operation inside the cluster. 286 void processBlock(Block &block) { 287 emitClusterStmt([&]() { 288 for (BlockArgument &blockArg : block.getArguments()) 289 valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg)); 290 291 // Emit a node for each operation. 292 std::optional<Node> prevNode; 293 for (Operation &op : block) { 294 Node nextNode = processOperation(&op); 295 if (printControlFlowEdges && prevNode) 296 emitEdgeStmt(*prevNode, nextNode, /*label=*/"", 297 kLineStyleControlFlow); 298 prevNode = nextNode; 299 } 300 }); 301 } 302 303 /// Process an operation. If the operation has regions, emit a cluster. 304 /// Otherwise, emit a node. 305 Node processOperation(Operation *op) { 306 Node node; 307 if (op->getNumRegions() > 0) { 308 // Emit cluster for op with regions. 309 node = emitClusterStmt( 310 [&]() { 311 for (Region ®ion : op->getRegions()) 312 processRegion(region); 313 }, 314 getLabel(op)); 315 } else { 316 node = emitNodeStmt(getLabel(op), kShapeNode, 317 backgroundColors[op->getName()].second); 318 } 319 320 // Insert data flow edges originating from each operand. 321 if (printDataFlowEdges) { 322 unsigned numOperands = op->getNumOperands(); 323 for (unsigned i = 0; i < numOperands; i++) 324 dataFlowEdges.push_back({op->getOperand(i), node, 325 numOperands == 1 ? "" : std::to_string(i)}); 326 } 327 328 for (Value result : op->getResults()) 329 valueToNode[result] = node; 330 331 return node; 332 } 333 334 /// Process a region. 335 void processRegion(Region ®ion) { 336 for (Block &block : region.getBlocks()) 337 processBlock(block); 338 } 339 340 /// Truncate long strings. 341 std::string truncateString(std::string str) { 342 if (str.length() <= maxLabelLen) 343 return str; 344 return str.substr(0, maxLabelLen) + "..."; 345 } 346 347 /// Output stream to write DOT file to. 348 raw_indented_ostream os; 349 /// A list of edges. For simplicity, should be emitted after all nodes were 350 /// emitted. 351 std::vector<std::string> edges; 352 /// Mapping of SSA values to Graphviz nodes/clusters. 353 DenseMap<Value, Node> valueToNode; 354 /// Output for data flow edges is delayed until the end to handle cycles 355 std::vector<std::tuple<Value, Node, std::string>> dataFlowEdges; 356 /// Counter for generating unique node/subgraph identifiers. 357 int counter = 0; 358 359 DenseMap<OperationName, std::pair<int, std::string>> backgroundColors; 360 }; 361 362 } // namespace 363 364 std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) { 365 return std::make_unique<PrintOpPass>(os); 366 } 367 368 /// Generate a CFG for a region and show it in a window. 369 static void llvmViewGraph(Region ®ion, const Twine &name) { 370 int fd; 371 std::string filename = llvm::createGraphFilename(name.str(), fd); 372 { 373 llvm::raw_fd_ostream os(fd, /*shouldClose=*/true); 374 if (fd == -1) { 375 llvm::errs() << "error opening file '" << filename << "' for writing\n"; 376 return; 377 } 378 PrintOpPass pass(os); 379 pass.emitRegionCFG(region); 380 } 381 llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT); 382 } 383 384 void mlir::Region::viewGraph(const Twine ®ionName) { 385 llvmViewGraph(*this, regionName); 386 } 387 388 void mlir::Region::viewGraph() { viewGraph("region"); } 389