1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/ViewOpGraph.h" 10 11 #include "mlir/IR/Block.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/IR/Operation.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Support/IndentedOstream.h" 16 #include "llvm/Support/Format.h" 17 #include "llvm/Support/GraphWriter.h" 18 #include <map> 19 #include <optional> 20 #include <utility> 21 22 namespace mlir { 23 #define GEN_PASS_DEF_VIEWOPGRAPH 24 #include "mlir/Transforms/Passes.h.inc" 25 } // namespace mlir 26 27 using namespace mlir; 28 29 static const StringRef kLineStyleControlFlow = "dashed"; 30 static const StringRef kLineStyleDataFlow = "solid"; 31 static const StringRef kShapeNode = "ellipse"; 32 static const StringRef kShapeNone = "plain"; 33 34 /// Return the size limits for eliding large attributes. 35 static int64_t getLargeAttributeSizeLimit() { 36 // Use the default from the printer flags if possible. 37 if (std::optional<int64_t> limit = 38 OpPrintingFlags().getLargeElementsAttrLimit()) 39 return *limit; 40 return 16; 41 } 42 43 /// Return all values printed onto a stream as a string. 44 static std::string strFromOs(function_ref<void(raw_ostream &)> func) { 45 std::string buf; 46 llvm::raw_string_ostream os(buf); 47 func(os); 48 return os.str(); 49 } 50 51 /// Escape special characters such as '\n' and quotation marks. 52 static std::string escapeString(std::string str) { 53 return strFromOs([&](raw_ostream &os) { os.write_escaped(str); }); 54 } 55 56 /// Put quotation marks around a given string. 57 static std::string quoteString(const std::string &str) { 58 return "\"" + str + "\""; 59 } 60 61 using AttributeMap = std::map<std::string, std::string>; 62 63 namespace { 64 65 /// This struct represents a node in the DOT language. Each node has an 66 /// identifier and an optional identifier for the cluster (subgraph) that 67 /// contains the node. 68 /// Note: In the DOT language, edges can be drawn only from nodes to nodes, but 69 /// not between clusters. However, edges can be clipped to the boundary of a 70 /// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new 71 /// cluster, an invisible "anchor" node is created. 72 struct Node { 73 public: 74 Node(int id = 0, std::optional<int> clusterId = std::nullopt) 75 : id(id), clusterId(clusterId) {} 76 77 int id; 78 std::optional<int> clusterId; 79 }; 80 81 /// This pass generates a Graphviz dataflow visualization of an MLIR operation. 82 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information 83 /// about the Graphviz DOT language. 84 class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> { 85 public: 86 PrintOpPass(raw_ostream &os) : os(os) {} 87 PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {} 88 89 void runOnOperation() override { 90 initColorMapping(*getOperation()); 91 emitGraph([&]() { 92 processOperation(getOperation()); 93 emitAllEdgeStmts(); 94 }); 95 } 96 97 /// Create a CFG graph for a region. Used in `Region::viewGraph`. 98 void emitRegionCFG(Region ®ion) { 99 printControlFlowEdges = true; 100 printDataFlowEdges = false; 101 initColorMapping(region); 102 emitGraph([&]() { processRegion(region); }); 103 } 104 105 private: 106 /// Generate a color mapping that will color every operation with the same 107 /// name the same way. It'll interpolate the hue in the HSV color-space, 108 /// attempting to keep the contrast suitable for black text. 109 template <typename T> 110 void initColorMapping(T &irEntity) { 111 backgroundColors.clear(); 112 SmallVector<Operation *> ops; 113 irEntity.walk([&](Operation *op) { 114 auto &entry = backgroundColors[op->getName()]; 115 if (entry.first == 0) 116 ops.push_back(op); 117 ++entry.first; 118 }); 119 for (auto indexedOps : llvm::enumerate(ops)) { 120 double hue = ((double)indexedOps.index()) / ops.size(); 121 backgroundColors[indexedOps.value()->getName()].second = 122 std::to_string(hue) + " 1.0 1.0"; 123 } 124 } 125 126 /// Emit all edges. This function should be called after all nodes have been 127 /// emitted. 128 void emitAllEdgeStmts() { 129 for (const std::string &edge : edges) 130 os << edge << ";\n"; 131 edges.clear(); 132 } 133 134 /// Emit a cluster (subgraph). The specified builder generates the body of the 135 /// cluster. Return the anchor node of the cluster. 136 Node emitClusterStmt(function_ref<void()> builder, std::string label = "") { 137 int clusterId = ++counter; 138 os << "subgraph cluster_" << clusterId << " {\n"; 139 os.indent(); 140 // Emit invisible anchor node from/to which arrows can be drawn. 141 Node anchorNode = emitNodeStmt(" ", kShapeNone); 142 os << attrStmt("label", quoteString(escapeString(std::move(label)))) 143 << ";\n"; 144 builder(); 145 os.unindent(); 146 os << "}\n"; 147 return Node(anchorNode.id, clusterId); 148 } 149 150 /// Generate an attribute statement. 151 std::string attrStmt(const Twine &key, const Twine &value) { 152 return (key + " = " + value).str(); 153 } 154 155 /// Emit an attribute list. 156 void emitAttrList(raw_ostream &os, const AttributeMap &map) { 157 os << "["; 158 interleaveComma(map, os, [&](const auto &it) { 159 os << this->attrStmt(it.first, it.second); 160 }); 161 os << "]"; 162 } 163 164 // Print an MLIR attribute to `os`. Large attributes are truncated. 165 void emitMlirAttr(raw_ostream &os, Attribute attr) { 166 // A value used to elide large container attribute. 167 int64_t largeAttrLimit = getLargeAttributeSizeLimit(); 168 169 // Always emit splat attributes. 170 if (isa<SplatElementsAttr>(attr)) { 171 attr.print(os); 172 return; 173 } 174 175 // Elide "big" elements attributes. 176 auto elements = dyn_cast<ElementsAttr>(attr); 177 if (elements && elements.getNumElements() > largeAttrLimit) { 178 os << std::string(elements.getShapedType().getRank(), '[') << "..." 179 << std::string(elements.getShapedType().getRank(), ']') << " : " 180 << elements.getType(); 181 return; 182 } 183 184 auto array = dyn_cast<ArrayAttr>(attr); 185 if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) { 186 os << "[...]"; 187 return; 188 } 189 190 // Print all other attributes. 191 std::string buf; 192 llvm::raw_string_ostream ss(buf); 193 attr.print(ss); 194 os << truncateString(ss.str()); 195 } 196 197 /// Append an edge to the list of edges. 198 /// Note: Edges are written to the output stream via `emitAllEdgeStmts`. 199 void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) { 200 AttributeMap attrs; 201 attrs["style"] = style.str(); 202 // Do not label edges that start/end at a cluster boundary. Such edges are 203 // clipped at the boundary, but labels are not. This can lead to labels 204 // floating around without any edge next to them. 205 if (!n1.clusterId && !n2.clusterId) 206 attrs["label"] = quoteString(escapeString(std::move(label))); 207 // Use `ltail` and `lhead` to draw edges between clusters. 208 if (n1.clusterId) 209 attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId); 210 if (n2.clusterId) 211 attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId); 212 213 edges.push_back(strFromOs([&](raw_ostream &os) { 214 os << llvm::format("v%i -> v%i ", n1.id, n2.id); 215 emitAttrList(os, attrs); 216 })); 217 } 218 219 /// Emit a graph. The specified builder generates the body of the graph. 220 void emitGraph(function_ref<void()> builder) { 221 os << "digraph G {\n"; 222 os.indent(); 223 // Edges between clusters are allowed only in compound mode. 224 os << attrStmt("compound", "true") << ";\n"; 225 builder(); 226 os.unindent(); 227 os << "}\n"; 228 } 229 230 /// Emit a node statement. 231 Node emitNodeStmt(std::string label, StringRef shape = kShapeNode, 232 StringRef background = "") { 233 int nodeId = ++counter; 234 AttributeMap attrs; 235 attrs["label"] = quoteString(escapeString(std::move(label))); 236 attrs["shape"] = shape.str(); 237 if (!background.empty()) { 238 attrs["style"] = "filled"; 239 attrs["fillcolor"] = ("\"" + background + "\"").str(); 240 } 241 os << llvm::format("v%i ", nodeId); 242 emitAttrList(os, attrs); 243 os << ";\n"; 244 return Node(nodeId); 245 } 246 247 /// Generate a label for an operation. 248 std::string getLabel(Operation *op) { 249 return strFromOs([&](raw_ostream &os) { 250 // Print operation name and type. 251 os << op->getName(); 252 if (printResultTypes) { 253 os << " : ("; 254 std::string buf; 255 llvm::raw_string_ostream ss(buf); 256 interleaveComma(op->getResultTypes(), ss); 257 os << truncateString(ss.str()) << ")"; 258 } 259 260 // Print attributes. 261 if (printAttrs) { 262 os << "\n"; 263 for (const NamedAttribute &attr : op->getAttrs()) { 264 os << '\n' << attr.getName().getValue() << ": "; 265 emitMlirAttr(os, attr.getValue()); 266 } 267 } 268 }); 269 } 270 271 /// Generate a label for a block argument. 272 std::string getLabel(BlockArgument arg) { 273 return "arg" + std::to_string(arg.getArgNumber()); 274 } 275 276 /// Process a block. Emit a cluster and one node per block argument and 277 /// operation inside the cluster. 278 void processBlock(Block &block) { 279 emitClusterStmt([&]() { 280 for (BlockArgument &blockArg : block.getArguments()) 281 valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg)); 282 283 // Emit a node for each operation. 284 std::optional<Node> prevNode; 285 for (Operation &op : block) { 286 Node nextNode = processOperation(&op); 287 if (printControlFlowEdges && prevNode) 288 emitEdgeStmt(*prevNode, nextNode, /*label=*/"", 289 kLineStyleControlFlow); 290 prevNode = nextNode; 291 } 292 }); 293 } 294 295 /// Process an operation. If the operation has regions, emit a cluster. 296 /// Otherwise, emit a node. 297 Node processOperation(Operation *op) { 298 Node node; 299 if (op->getNumRegions() > 0) { 300 // Emit cluster for op with regions. 301 node = emitClusterStmt( 302 [&]() { 303 for (Region ®ion : op->getRegions()) 304 processRegion(region); 305 }, 306 getLabel(op)); 307 } else { 308 node = emitNodeStmt(getLabel(op), kShapeNode, 309 backgroundColors[op->getName()].second); 310 } 311 312 // Insert data flow edges originating from each operand. 313 if (printDataFlowEdges) { 314 unsigned numOperands = op->getNumOperands(); 315 for (unsigned i = 0; i < numOperands; i++) 316 emitEdgeStmt(valueToNode[op->getOperand(i)], node, 317 /*label=*/numOperands == 1 ? "" : std::to_string(i), 318 kLineStyleDataFlow); 319 } 320 321 for (Value result : op->getResults()) 322 valueToNode[result] = node; 323 324 return node; 325 } 326 327 /// Process a region. 328 void processRegion(Region ®ion) { 329 for (Block &block : region.getBlocks()) 330 processBlock(block); 331 } 332 333 /// Truncate long strings. 334 std::string truncateString(std::string str) { 335 if (str.length() <= maxLabelLen) 336 return str; 337 return str.substr(0, maxLabelLen) + "..."; 338 } 339 340 /// Output stream to write DOT file to. 341 raw_indented_ostream os; 342 /// A list of edges. For simplicity, should be emitted after all nodes were 343 /// emitted. 344 std::vector<std::string> edges; 345 /// Mapping of SSA values to Graphviz nodes/clusters. 346 DenseMap<Value, Node> valueToNode; 347 /// Counter for generating unique node/subgraph identifiers. 348 int counter = 0; 349 350 DenseMap<OperationName, std::pair<int, std::string>> backgroundColors; 351 }; 352 353 } // namespace 354 355 std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) { 356 return std::make_unique<PrintOpPass>(os); 357 } 358 359 /// Generate a CFG for a region and show it in a window. 360 static void llvmViewGraph(Region ®ion, const Twine &name) { 361 int fd; 362 std::string filename = llvm::createGraphFilename(name.str(), fd); 363 { 364 llvm::raw_fd_ostream os(fd, /*shouldClose=*/true); 365 if (fd == -1) { 366 llvm::errs() << "error opening file '" << filename << "' for writing\n"; 367 return; 368 } 369 PrintOpPass pass(os); 370 pass.emitRegionCFG(region); 371 } 372 llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT); 373 } 374 375 void mlir::Region::viewGraph(const Twine ®ionName) { 376 llvmViewGraph(*this, regionName); 377 } 378 379 void mlir::Region::viewGraph() { viewGraph("region"); } 380