xref: /llvm-project/mlir/test/python/dialects/llvm.py (revision 79d4d165638b7587937fc60431e0865fd73c9334)
1# RUN: %PYTHON %s | FileCheck %s
2# This is just a smoke test that the dialect is functional.
3
4from mlir.ir import *
5from mlir.dialects import llvm
6
7
def constructAndPrintInModule(f):
    """Decorator: run `f` right away inside a fresh module, then print it.

    A banner with the name of `f` is printed first so that the output of each
    test can be located. `f` is then called with the insertion point set to
    the body of a newly created module, and the resulting module is printed.

    Args:
        f: zero-argument callable that builds IR and/or prints.

    Returns:
        `f` itself, unchanged, so the decorated name stays bound to it.
    """
    print("\nTEST:", f.__name__)
    # The location is created only after the context has been entered, the
    # same order in which the two managers are evaluated in a combined `with`.
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                f()
            print(mod)
    return f
16
17
18# CHECK-LABEL: testStructType
@constructAndPrintInModule
def testStructType():
    """Exercise literal, identified and opaque `llvm.StructType` objects.

    The printed output is matched by the FileCheck directives interleaved
    below, so the order of the `print` calls is significant.
    """
    empty_literal = llvm.StructType.get_literal([])
    print(empty_literal)
    # CHECK: !llvm.struct<()>

    i8, i32, i64 = (IntegerType.get_signless(width) for width in (8, 32, 64))
    for members in ([i8, i32, i64], [i32]):
        print(llvm.StructType.get_literal(members))
    print(llvm.StructType.get_literal([i32, i32], packed=True))
    literal = llvm.StructType.get_literal([i8, i32, i64])
    assert len(literal.body) == 3
    print(*literal.body)
    assert literal.name is None
    # CHECK: !llvm.struct<(i8, i32, i64)>
    # CHECK: !llvm.struct<(i32)>
    # CHECK: !llvm.struct<packed (i32, i32)>
    # CHECK: i8 i32 i64

    # Literal structs compare equal exactly when their bodies match.
    single_i32 = llvm.StructType.get_literal([i32])
    assert single_i32 == llvm.StructType.get_literal([i32])
    assert llvm.StructType.get_literal([i32]) != llvm.StructType.get_literal([i64])

    for struct_name in ("foo", "bar"):
        print(llvm.StructType.get_identified(struct_name))
    # CHECK: !llvm.struct<"foo", opaque>
    # CHECK: !llvm.struct<"bar", opaque>

    # Identified structs compare by name.
    foo_again = llvm.StructType.get_identified("foo")
    assert llvm.StructType.get_identified("foo") == foo_again
    bar_again = llvm.StructType.get_identified("bar")
    assert llvm.StructType.get_identified("foo") != bar_again

    foo_struct = llvm.StructType.get_identified("foo")
    print(foo_struct.name)
    print(foo_struct.body)
    assert foo_struct.opaque
    foo_struct.set_body([i32, i64])
    print(*foo_struct.body)
    print(foo_struct)
    assert not foo_struct.packed
    assert not foo_struct.opaque
    assert llvm.StructType.get_identified("foo") == foo_struct
    # CHECK: foo
    # CHECK: None
    # CHECK: i32 i64
    # CHECK: !llvm.struct<"foo", (i32, i64)>

    bar_struct = llvm.StructType.get_identified("bar")
    bar_struct.set_body([i32], packed=True)
    print(bar_struct)
    assert bar_struct.packed
    # CHECK: !llvm.struct<"bar", packed (i32)>

    # Same body, should not raise.
    foo_struct.set_body([i32, i64])

    # Setting a different body on an already-defined struct must be rejected.
    rejected = False
    try:
        foo_struct.set_body([])
    except ValueError:
        rejected = True
    assert rejected, "expected exception not raised"

    # Same element list but without the packed flag must be rejected as well.
    rejected = False
    try:
        bar_struct.set_body([i32])
    except ValueError:
        rejected = True
    assert rejected, "expected exception not raised"

    fresh_foo = llvm.StructType.new_identified("foo", [])
    print(fresh_foo)
    # Two further requests for the same name must yield distinct types.
    fresh_lhs = llvm.StructType.new_identified("foo", [])
    fresh_rhs = llvm.StructType.new_identified("foo", [])
    assert fresh_lhs != fresh_rhs
    # CHECK: !llvm.struct<"foo{{[^"]+}}

    opaque = llvm.StructType.get_opaque("opaque")
    print(opaque)
    assert opaque.opaque
    # CHECK: !llvm.struct<"opaque", opaque>
100
101
102# CHECK-LABEL: testSmoke
@constructAndPrintInModule
def testSmoke():
    """Create an undef op whose type is a parsed struct of sixteen f32 fields."""
    struct_text = "!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>"
    struct_of_f32 = Type.parse(struct_text)
    # No handle is kept: the directive below matches the op as it appears in
    # the module printed by the decorator.
    llvm.UndefOp(struct_of_f32)
    # CHECK: %0 = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
110
111
112# CHECK-LABEL: testPointerType
@constructAndPrintInModule
def testPointerType():
    """Print `llvm.PointerType` built with no argument and with the argument 1."""
    # CHECK: !llvm.ptr
    print(llvm.PointerType.get())

    # CHECK: !llvm.ptr<1>
    print(llvm.PointerType.get(1))
122
123
124# CHECK-LABEL: testConstant
@constructAndPrintInModule
def testConstant():
    """Build the i32 constant 128 and print the operation that defines it."""
    attr_128 = IntegerAttr.get(IntegerType.get_signless(32), 128)
    constant_value = llvm.mlir_constant(attr_128)
    # CHECK: %{{.*}} = llvm.mlir.constant(128 : i32) : i32
    print(constant_value.owner)
131
132
133# CHECK-LABEL: testIntrinsics
@constructAndPrintInModule
def testIntrinsics():
    """Build constant -> alloca -> constant -> memset and print each operation.

    The operations are created in exactly that order; the FileCheck captures
    below tie the memset operands back to the earlier results.
    """
    i32_type = IntegerType.get_signless(32)
    count = llvm.mlir_constant(IntegerAttr.get(i32_type, 128))
    # CHECK: %[[CST128:.*]] = llvm.mlir.constant(128 : i32) : i32
    print(count.owner)

    buffer = llvm.alloca(llvm.PointerType.get(), count, i32_type)
    # CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[CST128]] x i32 : (i32) -> !llvm.ptr
    print(buffer.owner)

    i8_type = IntegerType.get_signless(8)
    fill_byte = llvm.mlir_constant(IntegerAttr.get(i8_type, 0))
    # CHECK: %[[CST0:.+]] = llvm.mlir.constant(0 : i8) : i8
    print(fill_byte.owner)

    memset_op = llvm.intr_memset(buffer, fill_byte, count, False)
    # CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[CST0]], %[[CST128]]) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
    print(memset_op)
153