xref: /llvm-project/mlir/python/mlir/extras/types.py (revision e17c91341be2f6a2d229ab44a4290e7d0ef2e094)
#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
import builtins
from functools import partial
from typing import Optional, List
7
8from ..ir import (
9    Attribute,
10    BF16Type,
11    ComplexType,
12    F16Type,
13    F32Type,
14    F64Type,
15    Float4E2M1FNType,
16    Float6E2M3FNType,
17    Float6E3M2FNType,
18    Float8E3M4Type,
19    Float8E4M3B11FNUZType,
20    Float8E4M3FNType,
21    Float8E4M3Type,
22    Float8E5M2Type,
23    Float8E8M0FNUType,
24    FloatTF32Type,
25    FunctionType,
26    IndexType,
27    IntegerType,
28    MemRefType,
29    NoneType,
30    OpaqueType,
31    RankedTensorType,
32    StridedLayoutAttr,
33    StringAttr,
34    TupleType,
35    Type,
36    UnrankedMemRefType,
37    UnrankedTensorType,
38    VectorType,
39)
40
def index():
    """Return the builtin MLIR ``index`` type."""
    return IndexType.get()
42
43
def i(width):
    """Return the signless integer type with ``width`` bits."""
    signless = IntegerType.get_signless
    return signless(width)
46
47
def si(width):
    """Return the signed integer type with ``width`` bits."""
    signed = IntegerType.get_signed
    return signed(width)
50
51
def ui(width):
    """Return the unsigned integer type with ``width`` bits."""
    unsigned = IntegerType.get_unsigned
    return unsigned(width)
54
55
# Zero-argument factories for common integer types; each call builds the type
# anew. ``iN`` is signless, ``siN`` signed and ``uiN`` unsigned, and ``bool``
# is the 1-bit signless integer.
# NOTE: ``bool`` shadows the Python builtin within this module.
def bool():
    return i(1)


def i8():
    return i(8)


def i16():
    return i(16)


def i32():
    return i(32)


def i64():
    return i(64)


def si8():
    return si(8)


def si16():
    return si(16)


def si32():
    return si(32)


def si64():
    return si(64)


def ui8():
    return ui(8)


def ui16():
    return ui(16)


def ui32():
    return ui(32)


def ui64():
    return ui(64)
71
# Zero-argument factories for the standard floating-point types.
def f16():
    return F16Type.get()


def f32():
    return F32Type.get()


def tf32():
    return FloatTF32Type.get()


def f64():
    return F64Type.get()


def bf16():
    return BF16Type.get()
77
# Zero-argument factories for the narrow (4-, 6- and 8-bit) floating-point
# types. Each name matches the corresponding ``Float*Type`` binding.
def f8E5M2():
    return Float8E5M2Type.get()


def f8E4M3():
    return Float8E4M3Type.get()


def f8E4M3FN():
    return Float8E4M3FNType.get()


def f8E4M3B11FNUZ():
    return Float8E4M3B11FNUZType.get()


def f8E3M4():
    return Float8E3M4Type.get()


def f4E2M1FN():
    return Float4E2M1FNType.get()


def f6E2M3FN():
    return Float6E2M3FNType.get()


def f6E3M2FN():
    return Float6E3M2FNType.get()


def f8E8M0FNU():
    return Float8E8M0FNUType.get()
87
def none():
    """Return the builtin MLIR ``none`` type."""
    return NoneType.get()
89
90
def complex(type):
    """Return the complex type whose element type is ``type``.

    NOTE: both this function's name and its parameter shadow Python builtins.
    """
    complex_type = ComplexType.get(type)
    return complex_type
93
94
def opaque(dialect_namespace, type_data):
    """Return an ``OpaqueType`` in ``dialect_namespace`` carrying ``type_data``."""
    opaque_type = OpaqueType.get(dialect_namespace, type_data)
    return opaque_type
97
98
99def _shaped(*shape, element_type: Type = None, type_constructor=None):
100    if type_constructor is None:
101        raise ValueError("shaped is an abstract base class - cannot be constructed.")
102    if (element_type is None and shape and not isinstance(shape[-1], Type)) or (
103        shape and isinstance(shape[-1], Type) and element_type is not None
104    ):
105        raise ValueError(
106            f"Either element_type must be provided explicitly XOR last arg to tensor type constructor must be the element type."
107        )
108    if element_type is not None:
109        type = element_type
110        sizes = shape
111    else:
112        type = shape[-1]
113        sizes = shape[:-1]
114    if sizes:
115        return type_constructor(sizes, type)
116    else:
117        return type_constructor(type)
118
119
def vector(
    *shape,
    element_type: Type = None,
    # `builtins.bool`: this module rebinds the bare name `bool` to its i1 type
    # helper, so `List[bool]` here would refer to that function.
    scalable: Optional[List[builtins.bool]] = None,
    scalable_dims: Optional[List[int]] = None,
):
    """Return a ``VectorType``.

    The element type is either the last positional argument or passed as
    ``element_type=`` (not both). With no dimension sizes a 0-D vector type is
    built, e.g. ``vector(f32())``.

    Args:
      *shape: Dimension sizes, optionally followed by the element type.
      element_type: Explicit element type.
      scalable: Forwarded unchanged to ``VectorType.get``.
      scalable_dims: Forwarded unchanged to ``VectorType.get``.

    Raises:
      ValueError: propagated from ``_shaped`` when the element type is missing
        or given twice.
    """

    def _construct(*args):
        # `_shaped` calls its constructor with (sizes, element_type) when sizes
        # are present and with just (element_type,) otherwise. VectorType.get
        # always takes a shape first, so the sizeless form maps to an empty
        # shape (a 0-D vector) instead of calling VectorType.get without one.
        if len(args) == 1:
            args = ([], args[0])
        return VectorType.get(*args, scalable=scalable, scalable_dims=scalable_dims)

    return _shaped(*shape, element_type=element_type, type_constructor=_construct)
132    )
133
134
def tensor(*shape, element_type: Type = None, encoding: Optional[str] = None):
    """Return a ranked or unranked tensor type.

    The element type is either the last positional argument or passed as
    ``element_type=`` (not both). With no dimension sizes (``tensor(t)`` or
    ``tensor(element_type=t)``) an ``UnrankedTensorType`` is built; otherwise a
    ``RankedTensorType`` with the given sizes.

    Args:
      *shape: Dimension sizes, optionally followed by the element type.
      element_type: Explicit element type.
      encoding: Optional encoding string. It is wrapped in a ``StringAttr``
        and attached to ranked tensors.

    Raises:
      ValueError: if ``encoding`` is given for an unranked tensor.
    """
    encoding_attr = None if encoding is None else StringAttr.get(encoding)
    is_unranked = not shape or (len(shape) == 1 and isinstance(shape[-1], Type))
    if not is_unranked:
        return _shaped(
            *shape,
            element_type=element_type,
            type_constructor=partial(RankedTensorType.get, encoding=encoding_attr),
        )
    if encoding_attr is not None:
        raise ValueError("UnrankedTensorType does not support encoding.")
    return _shaped(
        *shape, element_type=element_type, type_constructor=UnrankedTensorType.get
    )
148    )
149
150
def memref(
    *shape,
    element_type: Type = None,
    memory_space: Optional[int] = None,
    layout: Optional[StridedLayoutAttr] = None,
):
    """Return a ranked or unranked memref type.

    The element type is either the last positional argument or passed as
    ``element_type=`` (not both). With no dimension sizes an
    ``UnrankedMemRefType`` is built; otherwise a ``MemRefType`` with the given
    sizes.

    Args:
      *shape: Dimension sizes, optionally followed by the element type.
      element_type: Explicit element type.
      memory_space: Optional memory space. It is converted to an attribute by
        parsing its string form.
      layout: Optional layout attribute; only valid for ranked memrefs.

    Raises:
      ValueError: if ``layout`` is given for an unranked memref. This mirrors
        ``tensor``'s rejection of an encoding on unranked tensors; previously
        the layout was silently dropped.
    """
    if memory_space is not None:
        # Round-trip through the textual form so a plain int becomes an
        # attribute that the type constructors accept.
        memory_space = Attribute.parse(str(memory_space))
    if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)):
        if layout is not None:
            raise ValueError("UnrankedMemRefType does not support layout.")
        return _shaped(
            *shape,
            element_type=element_type,
            type_constructor=partial(UnrankedMemRefType.get, memory_space=memory_space),
        )
    return _shaped(
        *shape,
        element_type=element_type,
        type_constructor=partial(
            MemRefType.get, memory_space=memory_space, layout=layout
        ),
    )
171    )
172
173
def tuple(*elements):
    """Return a ``TupleType`` whose members are the positional ``elements``.

    NOTE: shadows the Python builtin ``tuple`` within this module.
    """
    tuple_type = TupleType.get_tuple(elements)
    return tuple_type
176
177
def function(*, inputs, results):
    """Return a ``FunctionType`` with the given ``inputs`` and ``results`` types."""
    function_type = FunctionType.get(inputs, results)
    return function_type
180