xref: /llvm-project/mlir/lib/IR/Unit.cpp (revision 68f58812e3e99e31d77c0c23b6298489444dc0be)
1 //===- Unit.cpp - Support for manipulating IR Unit ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Unit.h"
10 #include "mlir/IR/Operation.h"
11 #include "mlir/IR/OperationSupport.h"
12 #include "mlir/IR/Region.h"
13 #include "llvm/Support/raw_ostream.h"
14 #include <iterator>
15 #include <sstream>
16 
17 using namespace mlir;
18 
printOp(llvm::raw_ostream & os,Operation * op,OpPrintingFlags & flags)19 static void printOp(llvm::raw_ostream &os, Operation *op,
20                     OpPrintingFlags &flags) {
21   if (!op) {
22     os << "<Operation:nullptr>";
23     return;
24   }
25   op->print(os, flags);
26 }
27 
printRegion(llvm::raw_ostream & os,Region * region,OpPrintingFlags & flags)28 static void printRegion(llvm::raw_ostream &os, Region *region,
29                         OpPrintingFlags &flags) {
30   if (!region) {
31     os << "<Region:nullptr>";
32     return;
33   }
34   os << "Region #" << region->getRegionNumber() << " for op ";
35   printOp(os, region->getParentOp(), flags);
36 }
37 
printBlock(llvm::raw_ostream & os,Block * block,OpPrintingFlags & flags)38 static void printBlock(llvm::raw_ostream &os, Block *block,
39                        OpPrintingFlags &flags) {
40   Region *region = block->getParent();
41   Block *entry = &region->front();
42   int blockId = std::distance(entry->getIterator(), block->getIterator());
43   os << "Block #" << blockId << " for ";
44   bool shouldSkipRegions = flags.shouldSkipRegions();
45   printRegion(os, region, flags.skipRegions());
46   if (!shouldSkipRegions)
47     block->print(os);
48 }
49 
print(llvm::raw_ostream & os,OpPrintingFlags flags) const50 void mlir::IRUnit::print(llvm::raw_ostream &os, OpPrintingFlags flags) const {
51   if (auto *op = llvm::dyn_cast_if_present<Operation *>(*this))
52     return printOp(os, op, flags);
53   if (auto *region = llvm::dyn_cast_if_present<Region *>(*this))
54     return printRegion(os, region, flags);
55   if (auto *block = llvm::dyn_cast_if_present<Block *>(*this))
56     return printBlock(os, block, flags);
57   llvm_unreachable("unknown IRUnit");
58 }
59 
operator <<(llvm::raw_ostream & os,const IRUnit & unit)60 llvm::raw_ostream &mlir::operator<<(llvm::raw_ostream &os, const IRUnit &unit) {
61   unit.print(os);
62   return os;
63 }
64