1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/ViewOpGraph.h" 10 11 #include "mlir/Analysis/TopologicalSortUtils.h" 12 #include "mlir/IR/Block.h" 13 #include "mlir/IR/BuiltinTypes.h" 14 #include "mlir/IR/Operation.h" 15 #include "mlir/Pass/Pass.h" 16 #include "mlir/Support/IndentedOstream.h" 17 #include "llvm/Support/Format.h" 18 #include "llvm/Support/GraphWriter.h" 19 #include <map> 20 #include <optional> 21 #include <utility> 22 23 namespace mlir { 24 #define GEN_PASS_DEF_VIEWOPGRAPH 25 #include "mlir/Transforms/Passes.h.inc" 26 } // namespace mlir 27 28 using namespace mlir; 29 30 static const StringRef kLineStyleControlFlow = "dashed"; 31 static const StringRef kLineStyleDataFlow = "solid"; 32 static const StringRef kShapeNode = "ellipse"; 33 static const StringRef kShapeNone = "plain"; 34 35 /// Return the size limits for eliding large attributes. 36 static int64_t getLargeAttributeSizeLimit() { 37 // Use the default from the printer flags if possible. 38 if (std::optional<int64_t> limit = 39 OpPrintingFlags().getLargeElementsAttrLimit()) 40 return *limit; 41 return 16; 42 } 43 44 /// Return all values printed onto a stream as a string. 45 static std::string strFromOs(function_ref<void(raw_ostream &)> func) { 46 std::string buf; 47 llvm::raw_string_ostream os(buf); 48 func(os); 49 return os.str(); 50 } 51 52 /// Escape special characters such as '\n' and quotation marks. 53 static std::string escapeString(std::string str) { 54 return strFromOs([&](raw_ostream &os) { os.write_escaped(str); }); 55 } 56 57 /// Put quotation marks around a given string. 58 static std::string quoteString(const std::string &str) { 59 return "\"" + str + "\""; 60 } 61 62 using AttributeMap = std::map<std::string, std::string>; 63 64 namespace { 65 66 /// This struct represents a node in the DOT language. Each node has an 67 /// identifier and an optional identifier for the cluster (subgraph) that 68 /// contains the node. 69 /// Note: In the DOT language, edges can be drawn only from nodes to nodes, but 70 /// not between clusters. However, edges can be clipped to the boundary of a 71 /// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new 72 /// cluster, an invisible "anchor" node is created. 73 struct Node { 74 public: 75 Node(int id = 0, std::optional<int> clusterId = std::nullopt) 76 : id(id), clusterId(clusterId) {} 77 78 int id; 79 std::optional<int> clusterId; 80 }; 81 82 /// This pass generates a Graphviz dataflow visualization of an MLIR operation. 83 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information 84 /// about the Graphviz DOT language. 85 class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> { 86 public: 87 PrintOpPass(raw_ostream &os) : os(os) {} 88 PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {} 89 90 void runOnOperation() override { 91 initColorMapping(*getOperation()); 92 emitGraph([&]() { 93 processOperation(getOperation()); 94 emitAllEdgeStmts(); 95 }); 96 } 97 98 /// Create a CFG graph for a region. Used in `Region::viewGraph`. 99 void emitRegionCFG(Region ®ion) { 100 printControlFlowEdges = true; 101 printDataFlowEdges = false; 102 initColorMapping(region); 103 emitGraph([&]() { processRegion(region); }); 104 } 105 106 private: 107 /// Generate a color mapping that will color every operation with the same 108 /// name the same way. It'll interpolate the hue in the HSV color-space, 109 /// attempting to keep the contrast suitable for black text. 110 template <typename T> 111 void initColorMapping(T &irEntity) { 112 backgroundColors.clear(); 113 SmallVector<Operation *> ops; 114 irEntity.walk([&](Operation *op) { 115 auto &entry = backgroundColors[op->getName()]; 116 if (entry.first == 0) 117 ops.push_back(op); 118 ++entry.first; 119 }); 120 for (auto indexedOps : llvm::enumerate(ops)) { 121 double hue = ((double)indexedOps.index()) / ops.size(); 122 backgroundColors[indexedOps.value()->getName()].second = 123 std::to_string(hue) + " 1.0 1.0"; 124 } 125 } 126 127 /// Emit all edges. This function should be called after all nodes have been 128 /// emitted. 129 void emitAllEdgeStmts() { 130 if (printDataFlowEdges) { 131 for (const auto &[value, node, label] : dataFlowEdges) { 132 emitEdgeStmt(valueToNode[value], node, label, kLineStyleDataFlow); 133 } 134 } 135 136 for (const std::string &edge : edges) 137 os << edge << ";\n"; 138 edges.clear(); 139 } 140 141 /// Emit a cluster (subgraph). The specified builder generates the body of the 142 /// cluster. Return the anchor node of the cluster. 143 Node emitClusterStmt(function_ref<void()> builder, std::string label = "") { 144 int clusterId = ++counter; 145 os << "subgraph cluster_" << clusterId << " {\n"; 146 os.indent(); 147 // Emit invisible anchor node from/to which arrows can be drawn. 148 Node anchorNode = emitNodeStmt(" ", kShapeNone); 149 os << attrStmt("label", quoteString(escapeString(std::move(label)))) 150 << ";\n"; 151 builder(); 152 os.unindent(); 153 os << "}\n"; 154 return Node(anchorNode.id, clusterId); 155 } 156 157 /// Generate an attribute statement. 158 std::string attrStmt(const Twine &key, const Twine &value) { 159 return (key + " = " + value).str(); 160 } 161 162 /// Emit an attribute list. 163 void emitAttrList(raw_ostream &os, const AttributeMap &map) { 164 os << "["; 165 interleaveComma(map, os, [&](const auto &it) { 166 os << this->attrStmt(it.first, it.second); 167 }); 168 os << "]"; 169 } 170 171 // Print an MLIR attribute to `os`. Large attributes are truncated. 172 void emitMlirAttr(raw_ostream &os, Attribute attr) { 173 // A value used to elide large container attribute. 174 int64_t largeAttrLimit = getLargeAttributeSizeLimit(); 175 176 // Always emit splat attributes. 177 if (isa<SplatElementsAttr>(attr)) { 178 attr.print(os); 179 return; 180 } 181 182 // Elide "big" elements attributes. 183 auto elements = dyn_cast<ElementsAttr>(attr); 184 if (elements && elements.getNumElements() > largeAttrLimit) { 185 os << std::string(elements.getShapedType().getRank(), '[') << "..." 186 << std::string(elements.getShapedType().getRank(), ']') << " : " 187 << elements.getType(); 188 return; 189 } 190 191 auto array = dyn_cast<ArrayAttr>(attr); 192 if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) { 193 os << "[...]"; 194 return; 195 } 196 197 // Print all other attributes. 198 std::string buf; 199 llvm::raw_string_ostream ss(buf); 200 attr.print(ss); 201 os << truncateString(ss.str()); 202 } 203 204 /// Append an edge to the list of edges. 205 /// Note: Edges are written to the output stream via `emitAllEdgeStmts`. 206 void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) { 207 AttributeMap attrs; 208 attrs["style"] = style.str(); 209 // Do not label edges that start/end at a cluster boundary. Such edges are 210 // clipped at the boundary, but labels are not. This can lead to labels 211 // floating around without any edge next to them. 212 if (!n1.clusterId && !n2.clusterId) 213 attrs["label"] = quoteString(escapeString(std::move(label))); 214 // Use `ltail` and `lhead` to draw edges between clusters. 215 if (n1.clusterId) 216 attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId); 217 if (n2.clusterId) 218 attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId); 219 220 edges.push_back(strFromOs([&](raw_ostream &os) { 221 os << llvm::format("v%i -> v%i ", n1.id, n2.id); 222 emitAttrList(os, attrs); 223 })); 224 } 225 226 /// Emit a graph. The specified builder generates the body of the graph. 227 void emitGraph(function_ref<void()> builder) { 228 os << "digraph G {\n"; 229 os.indent(); 230 // Edges between clusters are allowed only in compound mode. 231 os << attrStmt("compound", "true") << ";\n"; 232 builder(); 233 os.unindent(); 234 os << "}\n"; 235 } 236 237 /// Emit a node statement. 238 Node emitNodeStmt(std::string label, StringRef shape = kShapeNode, 239 StringRef background = "") { 240 int nodeId = ++counter; 241 AttributeMap attrs; 242 attrs["label"] = quoteString(escapeString(std::move(label))); 243 attrs["shape"] = shape.str(); 244 if (!background.empty()) { 245 attrs["style"] = "filled"; 246 attrs["fillcolor"] = ("\"" + background + "\"").str(); 247 } 248 os << llvm::format("v%i ", nodeId); 249 emitAttrList(os, attrs); 250 os << ";\n"; 251 return Node(nodeId); 252 } 253 254 /// Generate a label for an operation. 255 std::string getLabel(Operation *op) { 256 return strFromOs([&](raw_ostream &os) { 257 // Print operation name and type. 258 os << op->getName(); 259 if (printResultTypes) { 260 os << " : ("; 261 std::string buf; 262 llvm::raw_string_ostream ss(buf); 263 interleaveComma(op->getResultTypes(), ss); 264 os << truncateString(ss.str()) << ")"; 265 } 266 267 // Print attributes. 268 if (printAttrs) { 269 os << "\n"; 270 for (const NamedAttribute &attr : op->getAttrs()) { 271 os << '\n' << attr.getName().getValue() << ": "; 272 emitMlirAttr(os, attr.getValue()); 273 } 274 } 275 }); 276 } 277 278 /// Generate a label for a block argument. 279 std::string getLabel(BlockArgument arg) { 280 return "arg" + std::to_string(arg.getArgNumber()); 281 } 282 283 /// Process a block. Emit a cluster and one node per block argument and 284 /// operation inside the cluster. 285 void processBlock(Block &block) { 286 emitClusterStmt([&]() { 287 for (BlockArgument &blockArg : block.getArguments()) 288 valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg)); 289 290 // Emit a node for each operation. 291 std::optional<Node> prevNode; 292 for (Operation &op : block) { 293 Node nextNode = processOperation(&op); 294 if (printControlFlowEdges && prevNode) 295 emitEdgeStmt(*prevNode, nextNode, /*label=*/"", 296 kLineStyleControlFlow); 297 prevNode = nextNode; 298 } 299 }); 300 } 301 302 /// Process an operation. If the operation has regions, emit a cluster. 303 /// Otherwise, emit a node. 304 Node processOperation(Operation *op) { 305 Node node; 306 if (op->getNumRegions() > 0) { 307 // Emit cluster for op with regions. 308 node = emitClusterStmt( 309 [&]() { 310 for (Region ®ion : op->getRegions()) 311 processRegion(region); 312 }, 313 getLabel(op)); 314 } else { 315 node = emitNodeStmt(getLabel(op), kShapeNode, 316 backgroundColors[op->getName()].second); 317 } 318 319 // Insert data flow edges originating from each operand. 320 if (printDataFlowEdges) { 321 unsigned numOperands = op->getNumOperands(); 322 for (unsigned i = 0; i < numOperands; i++) 323 dataFlowEdges.push_back({op->getOperand(i), node, 324 numOperands == 1 ? "" : std::to_string(i)}); 325 } 326 327 for (Value result : op->getResults()) 328 valueToNode[result] = node; 329 330 return node; 331 } 332 333 /// Process a region. 334 void processRegion(Region ®ion) { 335 for (Block &block : region.getBlocks()) 336 processBlock(block); 337 } 338 339 /// Truncate long strings. 340 std::string truncateString(std::string str) { 341 if (str.length() <= maxLabelLen) 342 return str; 343 return str.substr(0, maxLabelLen) + "..."; 344 } 345 346 /// Output stream to write DOT file to. 347 raw_indented_ostream os; 348 /// A list of edges. For simplicity, should be emitted after all nodes were 349 /// emitted. 350 std::vector<std::string> edges; 351 /// Mapping of SSA values to Graphviz nodes/clusters. 352 DenseMap<Value, Node> valueToNode; 353 /// Output for data flow edges is delayed until the end to handle cycles 354 std::vector<std::tuple<Value, Node, std::string>> dataFlowEdges; 355 /// Counter for generating unique node/subgraph identifiers. 356 int counter = 0; 357 358 DenseMap<OperationName, std::pair<int, std::string>> backgroundColors; 359 }; 360 361 } // namespace 362 363 std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) { 364 return std::make_unique<PrintOpPass>(os); 365 } 366 367 /// Generate a CFG for a region and show it in a window. 368 static void llvmViewGraph(Region ®ion, const Twine &name) { 369 int fd; 370 std::string filename = llvm::createGraphFilename(name.str(), fd); 371 { 372 llvm::raw_fd_ostream os(fd, /*shouldClose=*/true); 373 if (fd == -1) { 374 llvm::errs() << "error opening file '" << filename << "' for writing\n"; 375 return; 376 } 377 PrintOpPass pass(os); 378 pass.emitRegionCFG(region); 379 } 380 llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT); 381 } 382 383 void mlir::Region::viewGraph(const Twine ®ionName) { 384 llvmViewGraph(*this, regionName); 385 } 386 387 void mlir::Region::viewGraph() { viewGraph("region"); } 388