xref: /llvm-project/llvm/utils/pipeline.py (revision b71edfaa4ec3c998aadb35255ce2f60bba2940b0)
1# Automatically formatted with yapf (https://github.com/google/yapf)
2"""Utility functions for creating and manipulating LLVM 'opt' NPM pipeline objects."""
3
4
def fromStr(pipeStr):
    """Create pipeline object from string representation.

    pipeStr is a comma-separated list of names. A name immediately followed
    by "(...)" denotes a pass-manager whose contents are parsed recursively,
    e.g. "a,A(b,B(c)),d".

    The returned pipeline object is a list of two-element lists:
    [None, passName] for a pass and [managerName, subPipelineObj] for a
    pass-manager. Empty names (e.g. produced by ",,") are skipped.

    Raises:
        ValueError: if the parentheses in pipeStr are unbalanced.
    """
    stack = []  # saved [kind, curr] of every enclosing pass-manager
    curr = []  # entries collected at the current nesting level
    tok = ""  # characters of the name currently being read
    kind = ""  # name of the pass-manager whose contents are in curr
    for pos, c in enumerate(pipeStr):
        if c == ",":
            if tok != "":
                curr.append([None, tok])
            tok = ""
        elif c == "(":
            # The name read so far belongs to a pass-manager; descend into it.
            stack.append([kind, curr])
            kind = tok
            curr = []
            tok = ""
        elif c == ")":
            # Without this check a stray ')' would surface as an opaque
            # IndexError from stack.pop().
            if not stack:
                raise ValueError(
                    "unbalanced ')' at position %d in pipeline %r" % (pos, pipeStr)
                )
            if tok != "":
                curr.append([None, tok])
            tok = ""
            # Close the current pass-manager and attach it to its parent level.
            oldKind = kind
            oldCurr = curr
            [kind, curr] = stack.pop()
            curr.append([oldKind, oldCurr])
        else:
            tok += c
    # An unclosed '(' would otherwise silently return only the innermost
    # level, dropping everything outside it.
    if stack:
        raise ValueError("unbalanced '(' in pipeline %r" % (pipeStr,))
    if tok != "":
        curr.append([None, tok])
    return curr
34
35
def toStr(pipeObj):
    """Create string representation of pipeline object.

    Passes are emitted by name and pass-managers as "name(...)", with their
    contents rendered recursively. Sibling entries are separated by commas,
    which is the syntax accepted by fromStr.
    """
    # Collect the rendered entries and join them once. This avoids growing a
    # string with repeated += and the manual "is this the last entry?" check
    # for comma placement.
    parts = []
    for c in pipeObj:
        if c[0]:
            parts.append(c[0] + "(" + toStr(c[1]) + ")")
        else:
            parts.append(c[1])
    return ",".join(parts)
50
51
def count(pipeObj):
    """Return the number of passes contained in pipeObj.

    Pass-manager entries are not counted themselves. Only the passes nested
    inside them, at any depth, contribute to the total.
    """
    return sum(count(entry[1]) if entry[0] else 1 for entry in pipeObj)
61
62
def split(pipeObj, splitIndex):
    """Split pipeObj in two directly after the pass numbered splitIndex.

    Passes are numbered from 0 in depth-first order; pass-manager entries
    do not receive a number. Passes numbered <= splitIndex go to the first
    result and all later passes go to the second.

    Every pass-manager is reproduced in both halves, holding whichever of its
    passes fall into that half. A manager may therefore be empty in one half;
    see prune.

    Returns a two-element list [firstHalf, secondHalf].
    """

    def walk(entries, nextIdx):
        # Returns (head, tail, nextIdx): entries for passes numbered
        # <= splitIndex, entries for the remaining passes, and the number
        # that the next pass after `entries` will receive.
        head = []
        tail = []
        for entry in entries:
            kind = entry[0]
            if not kind:
                target = head if nextIdx <= splitIndex else tail
                target.append([None, entry[1]])
                nextIdx += 1
                continue
            subHead, subTail, nextIdx = walk(entry[1], nextIdx)
            head.append([kind, subHead])
            tail.append([kind, subTail])
        return head, tail, nextIdx

    firstHalf, secondHalf, _ = walk(pipeObj, 0)
    return [firstHalf, secondHalf]
86
87
def remove(pipeObj, removeIndex):
    """Return a new pipeline object equal to pipeObj without one pass.

    Passes are numbered from 0 in depth-first order; pass-manager entries do
    not receive a number. The pass numbered removeIndex is dropped.
    Pass-managers are always kept, even if the removal leaves them empty.
    If removeIndex matches no pass, the result is an unmodified rebuild
    of pipeObj.
    """

    def walk(entries, passIdx):
        # Returns (kept, passIdx): the rebuilt entries and the number the
        # next pass after `entries` will receive.
        kept = []
        for entry in entries:
            kind = entry[0]
            if kind:
                subKept, passIdx = walk(entry[1], passIdx)
                kept.append([kind, subKept])
                continue
            if passIdx != removeIndex:
                kept.append([None, entry[1]])
            passIdx += 1
        return kept, passIdx

    result, _ = walk(pipeObj, 0)
    return result
106
107
def copy(srcPipeObj):
    """Return a structurally independent duplicate of srcPipeObj.

    Every entry and every nested entry list is freshly allocated, so mutating
    the result does not affect the source. The name objects themselves are
    not copied. Pass entries are always rebuilt with None as their first
    element.
    """
    return [
        [entry[0], copy(entry[1])] if entry[0] else [None, entry[1]]
        for entry in srcPipeObj
    ]
123
124
def prune(srcPipeObj):
    """Return a copy of srcPipeObj with empty pass-managers removed.

    A pass-manager is dropped when count() reports that it contains no
    passes. This includes managers whose only contents are other empty
    managers. Pass entries are rebuilt with None as their first element.
    """
    result = []
    for entry in srcPipeObj:
        kind = entry[0]
        if not kind:
            result.append([None, entry[1]])
        elif count(entry[1]):
            result.append([kind, prune(entry[1])])
    return result
141
142
if __name__ == "__main__":
    import unittest

    class Test(unittest.TestCase):
        def test_0(self):
            pipeStr = "a,b,A(c,B(d,e),f),g"
            pipeObj = fromStr(pipeStr)

            self.assertEqual(7, count(pipeObj))

            # Pin the parsed structure explicitly; comparing pipeObj with
            # itself could never fail.
            self.assertEqual(
                [
                    [None, "a"],
                    [None, "b"],
                    [
                        "A",
                        [
                            [None, "c"],
                            ["B", [[None, "d"], [None, "e"]]],
                            [None, "f"],
                        ],
                    ],
                    [None, "g"],
                ],
                pipeObj,
            )
            self.assertEqual(pipeObj, prune(pipeObj))
            self.assertEqual(pipeObj, copy(pipeObj))
            self.assertIsNot(pipeObj, copy(pipeObj))

            self.assertEqual(pipeStr, toStr(pipeObj))
            self.assertEqual(pipeStr, toStr(prune(pipeObj)))
            self.assertEqual(pipeStr, toStr(copy(pipeObj)))

            [pipeObjA, pipeObjB] = split(pipeObj, 3)
            self.assertEqual("a,b,A(c,B(d))", toStr(pipeObjA))
            self.assertEqual("A(B(e),f),g", toStr(pipeObjB))

            self.assertEqual("b,A(c,B(d,e),f),g", toStr(remove(pipeObj, 0)))
            self.assertEqual("a,b,A(c,B(d,e),f)", toStr(remove(pipeObj, 6)))

            pipeObjC = remove(pipeObj, 4)
            self.assertEqual("a,b,A(c,B(d),f),g", toStr(pipeObjC))
            pipeObjC = remove(pipeObjC, 3)
            self.assertEqual("a,b,A(c,B(),f),g", toStr(pipeObjC))
            pipeObjC = prune(pipeObjC)
            self.assertEqual("a,b,A(c,f),g", toStr(pipeObjC))

    # unittest.main() exits the interpreter itself (exit=True by default),
    # reporting success or failure through the exit status, so no explicit
    # exit call follows it.
    unittest.main()
177