# xref: /llvm-project/polly/utils/pyscop/isl.py (revision 5aafc6d58f3405662902cee006be11e599801b88)
from ctypes import *

# Shared handle to the isl shared library; every wrapper below calls through it.
isl = CDLL("libisl.so")
4
5
class Context:
    """Python handle for an isl_ctx allocated with isl_ctx_alloc.

    Every handle is registered in ``Context.instances`` under its raw pointer,
    so functions whose ctypes restype is ``Context.from_ptr`` (the isl
    ``*_get_ctx`` calls) map a returned pointer back to the same handle.
    """

    # Lazily created shared context, see getDefaultInstance().
    defaultInstance = None
    # Raw isl_ctx pointer -> Context handle. Entries are never removed, so a
    # registered handle stays alive (and __del__ does not run) until teardown.
    instances = {}

    def __init__(self):
        self.ptr = isl.isl_ctx_alloc()
        Context.instances[self.ptr] = self

    def __del__(self):
        isl.isl_ctx_free(self)

    def from_param(self):
        # ctypes calls Context.from_param(obj) for arguments declared as Context.
        return self.ptr

    @staticmethod
    def from_ptr(ptr):
        """Return the registered Context for raw pointer ``ptr`` (KeyError if unknown)."""
        registry = Context.instances
        return registry[ptr]

    @staticmethod
    def getDefaultInstance():
        """Return the shared Context, allocating it on first use."""
        shared = Context.defaultInstance
        if shared is None:
            shared = Context.defaultInstance = Context()
        return shared
31
32
class IslObject:
    """Base class for Python wrappers around isl objects.

    Subclasses supply ``isl_name()`` (the infix of the C names, e.g. ``"set"``
    for ``isl_set_*``) and ``from_ptr()``. An instance holds the isl pointer in
    ``self.ptr`` and releases it with ``isl_<name>_free`` when collected.
    """

    def __init__(self, string="", ctx=None, ptr=None):
        """Wrap an existing isl pointer, or parse ``string`` in ``ctx``.

        :param string: isl textual representation, parsed when ``ptr`` is None.
            A ``str`` is UTF-8 encoded because read_from_str is declared with
            a c_char_p argument, which only accepts bytes on Python 3.
        :param ctx: Context to parse in; defaults to Context.getDefaultInstance().
        :param ptr: existing isl object pointer; the wrapper frees it in __del__.
        """
        self.initialize_isl_methods()
        if ptr is not None:
            self.ptr = ptr
            self.ctx = self.get_isl_method("get_ctx")(self)
            return

        if ctx is None:
            ctx = Context.getDefaultInstance()

        if isinstance(string, str) and not isinstance(string, bytes):
            string = string.encode("utf-8")

        self.ctx = ctx
        self.ptr = self.get_isl_method("read_from_str")(ctx, string, -1)

    def __del__(self):
        # __init__ may have failed before ptr was assigned; freeing NULL is pointless.
        if getattr(self, "ptr", None):
            self.get_isl_method("free")(self)

    def from_param(self):
        # ctypes calls <Class>.from_param(obj) for arguments declared as <Class>.
        return self.ptr

    @property
    def context(self):
        return self.ctx

    def __str__(self):
        """Return the isl textual representation produced by a string Printer."""
        p = Printer(self.ctx)
        self.to_printer(p)
        text = p.getString()
        # c_char_p results are bytes on Python 3, but __str__ must return str.
        if isinstance(text, bytes) and not isinstance(text, str):
            text = text.decode("utf-8")
        return text

    __repr__ = __str__

    @staticmethod
    def isl_name():
        return "No isl name available"

    def initialize_isl_methods(self):
        """Configure ctypes signatures for this class's isl functions (once per class)."""
        if hasattr(self.__class__, "initialized"):
            return

        self.__class__.initialized = True
        self.get_isl_method("read_from_str").argtypes = [Context, c_char_p, c_int]
        self.get_isl_method("copy").argtypes = [self.__class__]
        # NOTE(review): object pointers are carried as C ints here; on 64-bit
        # platforms this may truncate addresses -- confirm before relying on it.
        self.get_isl_method("copy").restype = c_int
        self.get_isl_method("free").argtypes = [self.__class__]
        self.get_isl_method("get_ctx").argtypes = [self.__class__]
        self.get_isl_method("get_ctx").restype = Context.from_ptr
        getattr(isl, "isl_printer_print_" + self.isl_name()).argtypes = [
            Printer,
            self.__class__,
        ]

    def get_isl_method(self, name):
        """Return the isl function ``isl_<isl_name>_<name>``."""
        return getattr(isl, "isl_" + self.isl_name() + "_" + name)

    def to_printer(self, printer):
        """Print this object into ``printer``.

        The isl print call hands back a printer pointer. It is stored in
        ``printer.ptr``, as Printer.setFormat does, so that a NULL (error)
        result is not followed by use of a pointer isl may have released.
        """
        print_fn = getattr(isl, "isl_printer_print_" + self.isl_name())
        printer.ptr = print_fn(printer, self)
92
93
class BSet(IslObject):
    """Wrapper for an isl basic set (``isl_basic_set_*``)."""

    @staticmethod
    def isl_name():
        return "basic_set"

    @staticmethod
    def from_ptr(ptr):
        # A falsy (NULL) pointer from isl becomes None.
        return BSet(ptr=ptr) if ptr else None
104
105
class Set(IslObject):
    """Wrapper for an isl set (``isl_set_*``)."""

    @staticmethod
    def isl_name():
        return "set"

    @staticmethod
    def from_ptr(ptr):
        # A falsy (NULL) pointer from isl becomes None.
        return Set(ptr=ptr) if ptr else None
116
117
class USet(IslObject):
    """Wrapper for an isl union set (``isl_union_set_*``)."""

    @staticmethod
    def isl_name():
        return "union_set"

    @staticmethod
    def from_ptr(ptr):
        # A falsy (NULL) pointer from isl becomes None.
        return USet(ptr=ptr) if ptr else None
128
129
class BMap(IslObject):
    """Wrapper for an isl basic map (``isl_basic_map_*``)."""

    @staticmethod
    def isl_name():
        return "basic_map"

    @staticmethod
    def from_ptr(ptr):
        # A falsy (NULL) pointer from isl becomes None.
        return BMap(ptr=ptr) if ptr else None

    def __mul__(self, set):
        # ``bmap * bset`` delegates to intersect_domain, which is attached to
        # the class at import time from the functions table.
        return self.intersect_domain(set)
143
144
class Map(IslObject):
    """Wrapper for an isl map (``isl_map_*``)."""

    @staticmethod
    def isl_name():
        return "map"

    @staticmethod
    def from_ptr(ptr):
        # A falsy (NULL) pointer from isl becomes None.
        return Map(ptr=ptr) if ptr else None

    def __mul__(self, set):
        # ``map * set`` delegates to intersect_domain, which is attached to the
        # class at import time from the functions table.
        return self.intersect_domain(set)

    @staticmethod
    def _lex(order, dim):
        """Call ``isl_map_lex_<order>`` on a copy of ``dim``.

        The copy is taken first so the caller's Dim keeps its own pointer
        (presumably the isl constructor takes ownership of its argument).
        """
        dim_copy = isl.isl_dim_copy(dim)
        return getattr(isl, "isl_map_lex_" + order)(dim_copy)

    @staticmethod
    def lex_lt(dim):
        return Map._lex("lt", dim)

    @staticmethod
    def lex_le(dim):
        return Map._lex("le", dim)

    @staticmethod
    def lex_gt(dim):
        return Map._lex("gt", dim)

    @staticmethod
    def lex_ge(dim):
        return Map._lex("ge", dim)
178
179
class UMap(IslObject):
    """Wrapper for an isl union map (``isl_union_map_*``)."""

    @staticmethod
    def isl_name():
        return "union_map"

    @staticmethod
    def from_ptr(ptr):
        # A falsy (NULL) pointer from isl becomes None.
        return UMap(ptr=ptr) if ptr else None
190
191
class Dim(IslObject):
    """Wrapper for an isl dimension description (``isl_dim_*``)."""

    @staticmethod
    def isl_name():
        return "dim"

    @staticmethod
    def from_ptr(ptr):
        # A falsy (NULL) pointer from isl becomes None.
        return Dim(ptr=ptr) if ptr else None

    def initialize_isl_methods(self):
        """Configure copy/free/get_ctx signatures once for this class.

        Unlike the base class, Dim sets up no read_from_str or printer function.
        """
        cls = self.__class__
        if hasattr(cls, "initialized"):
            return

        cls.initialized = True
        copy_fn = self.get_isl_method("copy")
        copy_fn.argtypes = [cls]
        copy_fn.restype = c_int
        self.get_isl_method("free").argtypes = [cls]
        ctx_fn = self.get_isl_method("get_ctx")
        ctx_fn.argtypes = [cls]
        ctx_fn.restype = Context.from_ptr

    def __repr__(self):
        return str(self)

    def __str__(self):
        # Kinds 1, 2, 3 are queried as parameter, input and output sizes.
        n_param = isl.isl_dim_size(self, 1)
        n_in = isl.isl_dim_size(self, 2)
        n_out = isl.isl_dim_size(self, 3)

        if not n_in:
            return "<dim Set:%s, Param:%s>" % (n_out, n_param)
        return "<dim In:%s, Out:%s, Param:%s>" % (n_in, n_out, n_param)
227
228
class Printer:
    """Wrapper around an isl string printer created with isl_printer_to_str."""

    # Output formats accepted by setFormat().
    FORMAT_ISL = 0
    FORMAT_POLYLIB = 1
    FORMAT_POLYLIB_CONSTRAINTS = 2
    FORMAT_OMEGA = 3
    FORMAT_C = 4
    FORMAT_LATEX = 5
    FORMAT_EXT_POLYLIB = 6

    def __init__(self, ctx=None):
        """Create a string printer in ``ctx`` (default: the shared Context)."""
        if ctx is None:
            ctx = Context.getDefaultInstance()

        self.ctx = ctx
        self.ptr = isl.isl_printer_to_str(ctx)

    def setFormat(self, format):
        """Select one of the FORMAT_* output formats."""
        # Keep whichever printer pointer isl hands back.
        self.ptr = isl.isl_printer_set_output_format(self, format)

    def from_param(self):
        # ctypes calls Printer.from_param(obj) for arguments declared as Printer.
        return self.ptr

    def __del__(self):
        # __init__ may have failed before ptr was assigned; skip NULL pointers.
        if getattr(self, "ptr", None):
            isl.isl_printer_free(self)

    def getString(self):
        """Return everything printed so far as ``str`` (None if isl returns NULL).

        NOTE(review): isl_printer_get_str appears to hand back a newly allocated
        C string; with a c_char_p restype ctypes copies it and never frees the
        original buffer -- confirm against the isl manual.
        """
        text = isl.isl_printer_get_str(self)
        # c_char_p results are bytes on Python 3; callers such as __str__ need str.
        if isinstance(text, bytes) and not isinstance(text, str):
            text = text.decode("utf-8")
        return text
256
257
# isl functions attached as methods on the wrapper classes.
# Each entry is (operation, class receiving the method, operand classes,
# return type). The return type is c_int for predicates or a wrapper class
# whose from_ptr converts the returned pointer. The registration code only
# uses len(operands) to build argtypes; operands are copied before each call.
functions = [
    # Unary properties
    ("is_empty", BSet, [BSet], c_int),
    ("is_empty", Set, [Set], c_int),
    ("is_empty", USet, [USet], c_int),
    ("is_empty", BMap, [BMap], c_int),
    ("is_empty", Map, [Map], c_int),
    ("is_empty", UMap, [UMap], c_int),
    #         ("is_universe", Set, [Set], c_int),
    #         ("is_universe", Map, [Map], c_int),
    ("is_single_valued", Map, [Map], c_int),
    ("is_bijective", Map, [Map], c_int),
    ("is_wrapping", BSet, [BSet], c_int),
    ("is_wrapping", Set, [Set], c_int),
    # Binary properties
    ("is_equal", BSet, [BSet, BSet], c_int),
    ("is_equal", Set, [Set, Set], c_int),
    ("is_equal", USet, [USet, USet], c_int),
    ("is_equal", BMap, [BMap, BMap], c_int),
    ("is_equal", Map, [Map, Map], c_int),
    ("is_equal", UMap, [UMap, UMap], c_int),
    # is_disjoint missing
    # ("is_subset", BSet, [BSet, BSet], c_int),
    ("is_subset", Set, [Set, Set], c_int),
    ("is_subset", USet, [USet, USet], c_int),
    ("is_subset", BMap, [BMap, BMap], c_int),
    ("is_subset", Map, [Map, Map], c_int),
    ("is_subset", UMap, [UMap, UMap], c_int),
    # ("is_strict_subset", BSet, [BSet, BSet], c_int),
    ("is_strict_subset", Set, [Set, Set], c_int),
    ("is_strict_subset", USet, [USet, USet], c_int),
    ("is_strict_subset", BMap, [BMap, BMap], c_int),
    ("is_strict_subset", Map, [Map, Map], c_int),
    ("is_strict_subset", UMap, [UMap, UMap], c_int),
    # Unary Operations
    ("complement", Set, [Set], Set),
    ("reverse", BMap, [BMap], BMap),
    ("reverse", Map, [Map], Map),
    ("reverse", UMap, [UMap], UMap),
    # Projection missing
    ("range", BMap, [BMap], BSet),
    ("range", Map, [Map], Set),
    ("range", UMap, [UMap], USet),
    ("domain", BMap, [BMap], BSet),
    ("domain", Map, [Map], Set),
    ("domain", UMap, [UMap], USet),
    ("identity", Set, [Set], Map),
    ("identity", USet, [USet], UMap),
    ("deltas", BMap, [BMap], BSet),
    ("deltas", Map, [Map], Set),
    ("deltas", UMap, [UMap], USet),
    ("coalesce", Set, [Set], Set),
    ("coalesce", USet, [USet], USet),
    ("coalesce", Map, [Map], Map),
    ("coalesce", UMap, [UMap], UMap),
    ("detect_equalities", BSet, [BSet], BSet),
    ("detect_equalities", Set, [Set], Set),
    ("detect_equalities", USet, [USet], USet),
    ("detect_equalities", BMap, [BMap], BMap),
    ("detect_equalities", Map, [Map], Map),
    ("detect_equalities", UMap, [UMap], UMap),
    # Like affine_hull, isl returns the convex, simple and polyhedral hulls of
    # a set/map as a *basic* set/map. Wrapping the result in Set/Map would
    # later release it with the wrong *_free function.
    ("convex_hull", Set, [Set], BSet),
    ("convex_hull", Map, [Map], BMap),
    ("simple_hull", Set, [Set], BSet),
    ("simple_hull", Map, [Map], BMap),
    ("affine_hull", BSet, [BSet], BSet),
    ("affine_hull", Set, [Set], BSet),
    ("affine_hull", USet, [USet], USet),
    ("affine_hull", BMap, [BMap], BMap),
    ("affine_hull", Map, [Map], BMap),
    ("affine_hull", UMap, [UMap], UMap),
    ("polyhedral_hull", Set, [Set], BSet),
    ("polyhedral_hull", USet, [USet], USet),
    ("polyhedral_hull", Map, [Map], BMap),
    ("polyhedral_hull", UMap, [UMap], UMap),
    # Power missing
    # Transitive closure missing
    # Reaching path lengths missing
    ("wrap", BMap, [BMap], BSet),
    ("wrap", Map, [Map], Set),
    ("wrap", UMap, [UMap], USet),
    ("unwrap", BSet, [BSet], BMap),
    ("unwrap", Set, [Set], Map),
    ("unwrap", USet, [USet], UMap),
    ("flatten", Set, [Set], Set),
    ("flatten", Map, [Map], Map),
    ("flatten_map", Set, [Set], Map),
    # Dimension manipulation missing
    # Binary Operations
    ("intersect", BSet, [BSet, BSet], BSet),
    ("intersect", Set, [Set, Set], Set),
    ("intersect", USet, [USet, USet], USet),
    ("intersect", BMap, [BMap, BMap], BMap),
    ("intersect", Map, [Map, Map], Map),
    ("intersect", UMap, [UMap, UMap], UMap),
    ("intersect_domain", BMap, [BMap, BSet], BMap),
    ("intersect_domain", Map, [Map, Set], Map),
    ("intersect_domain", UMap, [UMap, USet], UMap),
    ("intersect_range", BMap, [BMap, BSet], BMap),
    ("intersect_range", Map, [Map, Set], Map),
    ("intersect_range", UMap, [UMap, USet], UMap),
    ("union", BSet, [BSet, BSet], Set),
    ("union", Set, [Set, Set], Set),
    ("union", USet, [USet, USet], USet),
    ("union", BMap, [BMap, BMap], Map),
    ("union", Map, [Map, Map], Map),
    ("union", UMap, [UMap, UMap], UMap),
    ("subtract", Set, [Set, Set], Set),
    ("subtract", Map, [Map, Map], Map),
    ("subtract", USet, [USet, USet], USet),
    ("subtract", UMap, [UMap, UMap], UMap),
    ("apply", BSet, [BSet, BMap], BSet),
    ("apply", Set, [Set, Map], Set),
    ("apply", USet, [USet, UMap], USet),
    ("apply_domain", BMap, [BMap, BMap], BMap),
    ("apply_domain", Map, [Map, Map], Map),
    ("apply_domain", UMap, [UMap, UMap], UMap),
    ("apply_range", BMap, [BMap, BMap], BMap),
    ("apply_range", Map, [Map, Map], Map),
    ("apply_range", UMap, [UMap, UMap], UMap),
    ("gist", BSet, [BSet, BSet], BSet),
    ("gist", Set, [Set, Set], Set),
    ("gist", USet, [USet, USet], USet),
    ("gist", BMap, [BMap, BMap], BMap),
    ("gist", Map, [Map, Map], Map),
    ("gist", UMap, [UMap, UMap], UMap),
    # Lexicographic Optimizations
    # partial_lexmin missing
    # The lexmin/lexmax of a basic set/map is returned by isl as a (non-basic)
    # set/map, so those results are wrapped as Set/Map.
    ("lexmin", BSet, [BSet], Set),
    ("lexmin", Set, [Set], Set),
    ("lexmin", USet, [USet], USet),
    ("lexmin", BMap, [BMap], Map),
    ("lexmin", Map, [Map], Map),
    ("lexmin", UMap, [UMap], UMap),
    ("lexmax", BSet, [BSet], Set),
    ("lexmax", Set, [Set], Set),
    ("lexmax", USet, [USet], USet),
    ("lexmax", BMap, [BMap], Map),
    ("lexmax", Map, [Map], Map),
    ("lexmax", UMap, [UMap], UMap),
    # Undocumented
    ("lex_lt_union_set", USet, [USet, USet], UMap),
    ("lex_le_union_set", USet, [USet, USet], UMap),
    ("lex_gt_union_set", USet, [USet, USet], UMap),
    ("lex_ge_union_set", USet, [USet, USet], UMap),
]
# Functions attached as methods whose operands are passed to isl as-is
# (no copy is taken first). Same tuple layout as ``functions``.
# Unary properties: get_dim on every set/map wrapper, returning a Dim.
keep_functions = [
    ("get_dim", wrapper, [wrapper], Dim)
    for wrapper in (BSet, Set, USet, BMap, Map, UMap)
]
413
414
def addIslFunction(object, name):
    """Attach ``isl_<isl_name>_<name>`` to class ``object`` as method ``name``.

    The method copies each operand before calling isl (via islFunctionOneOp /
    islFunctionTwoOp). The isl function's argtypes must already be set; their
    count selects the arity.

    :raises ValueError: if the function takes neither one nor two arguments.
    """
    functionName = "isl_" + object.isl_name() + "_" + name
    islFunction = getattr(isl, functionName)
    arity = len(islFunction.argtypes)
    if arity == 1:

        def f(a):
            return islFunctionOneOp(islFunction, a)

    elif arity == 2:

        def f(a, b):
            return islFunctionTwoOp(islFunction, a, b)

    else:
        raise ValueError("unsupported arity %d for %s" % (arity, functionName))

    # A class's __dict__ is a read-only mappingproxy for new-style (all
    # Python 3) classes, so assigning into it fails; setattr works everywhere.
    setattr(object, name, f)
423
424
def islFunctionOneOp(islFunction, ops):
    """Call unary ``islFunction`` on a fresh isl copy of ``ops``.

    The copy means the caller's wrapper keeps its own pointer regardless of
    what ``islFunction`` does with its argument.
    """
    copier = getattr(isl, "isl_%s_copy" % ops.isl_name())
    duplicate = copier(ops)
    return islFunction(duplicate)
428
429
def islFunctionTwoOp(islFunction, opOne, opTwo):
    """Call binary ``islFunction`` on fresh isl copies of both operands.

    Copies are made left to right so each caller's wrapper keeps its pointer.
    """
    duplicates = [
        getattr(isl, "isl_%s_copy" % operand.isl_name())(operand)
        for operand in (opOne, opTwo)
    ]
    return islFunction(*duplicates)
434
435
# Set the ctypes signature of every entry in ``functions`` and attach it as a
# method on its class. Operands arrive as raw pointers (C ints here);
# predicate results stay c_int, other results are wrapped via from_ptr.
for operation, base, operands, ret in functions:
    functionName = "isl_%s_%s" % (base.isl_name(), operation)
    islFunction = getattr(isl, functionName)
    if len(operands) in (1, 2):
        islFunction.argtypes = [c_int] * len(operands)

    islFunction.restype = ret if ret == c_int else ret.from_ptr

    addIslFunction(base, operation)
450
451
def addIslFunctionKeep(object, name):
    """Attach ``isl_<isl_name>_<name>`` to class ``object`` as method ``name``.

    Unlike addIslFunction, operands are passed to isl without copying (via
    islFunctionOneOpKeep / islFunctionTwoOpKeep). The isl function's argtypes
    must already be set; their count selects the arity.

    :raises ValueError: if the function takes neither one nor two arguments.
    """
    functionName = "isl_" + object.isl_name() + "_" + name
    islFunction = getattr(isl, functionName)
    arity = len(islFunction.argtypes)
    if arity == 1:

        def f(a):
            return islFunctionOneOpKeep(islFunction, a)

    elif arity == 2:

        def f(a, b):
            return islFunctionTwoOpKeep(islFunction, a, b)

    else:
        raise ValueError("unsupported arity %d for %s" % (arity, functionName))

    # A class's __dict__ is a read-only mappingproxy for new-style (all
    # Python 3) classes, so assigning into it fails; setattr works everywhere.
    setattr(object, name, f)
460
461
def islFunctionOneOpKeep(islFunction, ops):
    """Call unary ``islFunction`` directly on ``ops``, without copying it."""
    result = islFunction(ops)
    return result
464
465
def islFunctionTwoOpKeep(islFunction, opOne, opTwo):
    """Call binary ``islFunction`` directly on both operands, without copies."""
    operands = (opOne, opTwo)
    return islFunction(*operands)
468
469
# Same registration as for ``functions``, but the attached methods pass
# their operands to isl without copying them first.
for operation, base, operands, ret in keep_functions:
    functionName = "isl_%s_%s" % (base.isl_name(), operation)
    islFunction = getattr(isl, functionName)
    if len(operands) in (1, 2):
        islFunction.argtypes = [c_int] * len(operands)

    islFunction.restype = ret if ret == c_int else ret.from_ptr

    addIslFunctionKeep(base, operation)
484
# Explicit ctypes signatures for isl entry points. Several of these are also
# set per class by IslObject.initialize_isl_methods; declaring them here makes
# them available before the first instance of a class is created.
# NOTE(review): isl object pointers are carried as c_int throughout this
# module; on 64-bit platforms that may truncate addresses -- confirm.

# Context
isl.isl_ctx_free.argtypes = [Context]

# Basic sets and sets
isl.isl_basic_set_read_from_str.argtypes = [Context, c_char_p, c_int]
isl.isl_set_read_from_str.argtypes = [Context, c_char_p, c_int]
isl.isl_basic_set_free.argtypes = [BSet]
isl.isl_set_free.argtypes = [Set]
isl.isl_basic_set_copy.argtypes = [BSet]
isl.isl_basic_set_copy.restype = c_int
isl.isl_set_copy.argtypes = [Set]
isl.isl_set_copy.restype = c_int
isl.isl_basic_set_get_ctx.argtypes = [BSet]
isl.isl_basic_set_get_ctx.restype = Context.from_ptr
isl.isl_set_get_ctx.argtypes = [Set]
isl.isl_set_get_ctx.restype = Context.from_ptr
isl.isl_basic_set_get_dim.argtypes = [BSet]
isl.isl_basic_set_get_dim.restype = Dim.from_ptr
isl.isl_set_get_dim.argtypes = [Set]
isl.isl_set_get_dim.restype = Dim.from_ptr
isl.isl_union_set_get_dim.argtypes = [USet]
isl.isl_union_set_get_dim.restype = Dim.from_ptr

# Basic maps and maps
isl.isl_basic_map_read_from_str.argtypes = [Context, c_char_p, c_int]
isl.isl_map_read_from_str.argtypes = [Context, c_char_p, c_int]
isl.isl_basic_map_free.argtypes = [BMap]
isl.isl_map_free.argtypes = [Map]
isl.isl_basic_map_copy.argtypes = [BMap]
isl.isl_basic_map_copy.restype = c_int
isl.isl_map_copy.argtypes = [Map]
isl.isl_map_copy.restype = c_int
isl.isl_basic_map_get_ctx.argtypes = [BMap]
isl.isl_basic_map_get_ctx.restype = Context.from_ptr
isl.isl_map_get_ctx.argtypes = [Map]
isl.isl_map_get_ctx.restype = Context.from_ptr
isl.isl_basic_map_get_dim.argtypes = [BMap]
isl.isl_basic_map_get_dim.restype = Dim.from_ptr
isl.isl_map_get_dim.argtypes = [Map]
isl.isl_map_get_dim.restype = Dim.from_ptr
isl.isl_union_map_get_dim.argtypes = [UMap]
isl.isl_union_map_get_dim.restype = Dim.from_ptr

# Printer
isl.isl_printer_free.argtypes = [Printer]
isl.isl_printer_to_str.argtypes = [Context]
isl.isl_printer_print_basic_set.argtypes = [Printer, BSet]
isl.isl_printer_print_set.argtypes = [Printer, Set]
isl.isl_printer_print_basic_map.argtypes = [Printer, BMap]
isl.isl_printer_print_map.argtypes = [Printer, Map]
isl.isl_printer_get_str.argtypes = [Printer]
isl.isl_printer_get_str.restype = c_char_p
isl.isl_printer_set_output_format.argtypes = [Printer, c_int]
isl.isl_printer_set_output_format.restype = c_int

# Dim
isl.isl_dim_size.argtypes = [Dim, c_int]
isl.isl_dim_size.restype = c_int

# Lexicographic order maps (used by Map.lex_*); the argument is a copied dim.
isl.isl_map_lex_lt.argtypes = [c_int]
isl.isl_map_lex_lt.restype = Map.from_ptr
isl.isl_map_lex_le.argtypes = [c_int]
isl.isl_map_lex_le.restype = Map.from_ptr
isl.isl_map_lex_gt.argtypes = [c_int]
isl.isl_map_lex_gt.restype = Map.from_ptr
isl.isl_map_lex_ge.argtypes = [c_int]
isl.isl_map_lex_ge.restype = Map.from_ptr

# Dataflow analysis (used by dependences): four copied input pointers
# followed by four out-parameters passed by reference.
isl.isl_union_map_compute_flow.argtypes = [
    c_int,
    c_int,
    c_int,
    c_int,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
]
557
558
def dependences(sink, must_source, may_source, schedule):
    """Run isl_union_map_compute_flow on copies of the given relations.

    Each argument is an isl wrapper object (presumably a UMap, since the isl
    function works on union maps -- confirm at call sites). Each argument is
    copied before the call, so the caller's wrappers keep their own pointers.

    Returns ``(must_dep, may_dep, must_no_source, may_no_source)`` as
    ``(UMap, UMap, USet, USet)`` wrappers. An entry is None when isl leaves the
    corresponding output NULL.
    """

    def _copy(obj):
        return getattr(isl, "isl_" + obj.isl_name() + "_copy")(obj)

    sink = _copy(sink)
    must_source = _copy(must_source)
    may_source = _copy(may_source)
    schedule = _copy(schedule)

    # The out-parameters receive object pointers (they are wrapped via
    # from_ptr below), so give them pointer-sized storage rather than a C int.
    must_dep = c_void_p()
    may_dep = c_void_p()
    must_no_source = c_void_p()
    may_no_source = c_void_p()
    isl.isl_union_map_compute_flow(
        sink,
        must_source,
        may_source,
        schedule,
        byref(must_dep),
        byref(may_dep),
        byref(must_no_source),
        byref(may_no_source),
    )

    # c_void_p.value is None for NULL, which from_ptr turns into None.
    return (
        UMap.from_ptr(must_dep.value),
        UMap.from_ptr(may_dep.value),
        USet.from_ptr(must_no_source.value),
        USet.from_ptr(may_no_source.value),
    )
585
586
# Public API: all wrapper classes (including those returned by wrapped isl
# functions) plus the dataflow helper.
__all__ = [
    "Set",
    "Map",
    "Printer",
    "Context",
    "BSet",
    "USet",
    "BMap",
    "UMap",
    "Dim",
    "dependences",
]
588