#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from functools import partial
from typing import Optional, List

from ..ir import (
    Attribute,
    BF16Type,
    ComplexType,
    F16Type,
    F32Type,
    F64Type,
    Float4E2M1FNType,
    Float6E2M3FNType,
    Float6E3M2FNType,
    Float8E3M4Type,
    Float8E4M3B11FNUZType,
    Float8E4M3FNType,
    Float8E4M3Type,
    Float8E5M2Type,
    Float8E8M0FNUType,
    FloatTF32Type,
    FunctionType,
    IndexType,
    IntegerType,
    MemRefType,
    NoneType,
    OpaqueType,
    RankedTensorType,
    StridedLayoutAttr,
    StringAttr,
    TupleType,
    Type,
    UnrankedMemRefType,
    UnrankedTensorType,
    VectorType,
)

# Short-hand constructors for builtin types. Every helper in this module is a
# callable (not a cached instance): the underlying `*.get()` call happens each
# time the helper is invoked.

index = lambda: IndexType.get()


def i(width):
    """Return `IntegerType.get_signless(width)`."""
    return IntegerType.get_signless(width)


def si(width):
    """Return `IntegerType.get_signed(width)`."""
    return IntegerType.get_signed(width)


def ui(width):
    """Return `IntegerType.get_unsigned(width)`."""
    return IntegerType.get_unsigned(width)


# NOTE: `bool` intentionally shadows the Python builtin within this module;
# it is part of the public helper names and yields `i(1)`.
bool = lambda: i(1)
i8 = lambda: i(8)
i16 = lambda: i(16)
i32 = lambda: i(32)
i64 = lambda: i(64)

si8 = lambda: si(8)
si16 = lambda: si(16)
si32 = lambda: si(32)
si64 = lambda: si(64)

ui8 = lambda: ui(8)
ui16 = lambda: ui(16)
ui32 = lambda: ui(32)
ui64 = lambda: ui(64)

f16 = lambda: F16Type.get()
f32 = lambda: F32Type.get()
tf32 = lambda: FloatTF32Type.get()
f64 = lambda: F64Type.get()
bf16 = lambda: BF16Type.get()

f8E5M2 = lambda: Float8E5M2Type.get()
f8E4M3 = lambda: Float8E4M3Type.get()
f8E4M3FN = lambda: Float8E4M3FNType.get()
f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
f8E3M4 = lambda: Float8E3M4Type.get()
f4E2M1FN = lambda: Float4E2M1FNType.get()
f6E2M3FN = lambda: Float6E2M3FNType.get()
f6E3M2FN = lambda: Float6E3M2FNType.get()
f8E8M0FNU = lambda: Float8E8M0FNUType.get()

none = lambda: NoneType.get()


def complex(type):
    """Return `ComplexType.get(type)` for the given element type."""
    return ComplexType.get(type)


def opaque(dialect_namespace, type_data):
    """Return `OpaqueType.get(dialect_namespace, type_data)`."""
    return OpaqueType.get(dialect_namespace, type_data)


def _shaped(*shape, element_type: Optional[Type] = None, type_constructor=None):
    """Shared implementation behind `vector`, `tensor` and `memref`.

    The element type must be supplied in exactly one of two ways: as the
    `element_type` keyword, or as the last positional argument (an instance
    of `Type`). All remaining positional arguments are treated as the sizes.

    Args:
        *shape: Dimension sizes, optionally followed by the element type.
        element_type: Element type, when not given as the last positional.
        type_constructor: Callable invoked as `type_constructor(sizes, elem)`
            when sizes are present, or `type_constructor(elem)` otherwise.

    Raises:
        ValueError: If `type_constructor` is missing, or if the element type
            is supplied both ways or not at all.
    """
    if type_constructor is None:
        raise ValueError("shaped is an abstract base class - cannot be constructed.")

    trailing_is_type = len(shape) > 0 and isinstance(shape[-1], Type)
    # Exactly one source for the element type is allowed. This also rejects
    # the "no positional args and no element_type" case, which previously
    # fell through to `shape[-1]` and surfaced as a bare IndexError.
    if (element_type is not None) == trailing_is_type:
        raise ValueError(
            "Either element_type must be provided explicitly XOR last arg to tensor type constructor must be the element type."
        )

    if element_type is not None:
        elem, sizes = element_type, shape
    else:
        elem, sizes = shape[-1], shape[:-1]

    if sizes:
        return type_constructor(sizes, elem)
    return type_constructor(elem)


def vector(
    *shape,
    element_type: Optional[Type] = None,
    scalable: Optional[List[bool]] = None,
    scalable_dims: Optional[List[int]] = None,
):
    """Build a vector type via `VectorType.get`.

    `shape` and `element_type` follow the `_shaped` convention; `scalable`
    and `scalable_dims` are forwarded unchanged as keywords to
    `VectorType.get`.
    """
    return _shaped(
        *shape,
        element_type=element_type,
        type_constructor=partial(
            VectorType.get, scalable=scalable, scalable_dims=scalable_dims
        ),
    )


def tensor(
    *shape, element_type: Optional[Type] = None, encoding: Optional[str] = None
):
    """Build a ranked or unranked tensor type.

    With no sizes (no positional arguments, or a single positional that is a
    `Type`), `UnrankedTensorType.get` is used; otherwise
    `RankedTensorType.get` is used. A non-None `encoding` string is wrapped
    with `StringAttr.get` and passed to the ranked constructor.

    Raises:
        ValueError: If `encoding` is given for the unranked form, or if the
            element type is not specified exactly once.
    """
    if encoding is not None:
        encoding = StringAttr.get(encoding)
    if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)):
        if encoding is not None:
            raise ValueError("UnrankedTensorType does not support encoding.")
        return _shaped(
            *shape, element_type=element_type, type_constructor=UnrankedTensorType.get
        )
    return _shaped(
        *shape,
        element_type=element_type,
        type_constructor=partial(RankedTensorType.get, encoding=encoding),
    )


def memref(
    *shape,
    element_type: Optional[Type] = None,
    memory_space: Optional[int] = None,
    layout: Optional[StridedLayoutAttr] = None,
):
    """Build a ranked or unranked memref type.

    With no sizes (no positional arguments, or a single positional that is a
    `Type`), `UnrankedMemRefType.get` is used; otherwise `MemRefType.get` is
    used. A non-None `memory_space` is converted with
    `Attribute.parse(str(memory_space))` before being forwarded. `layout` is
    forwarded only to the ranked constructor; it is not used in the unranked
    branch.

    Raises:
        ValueError: If the element type is not specified exactly once.
    """
    if memory_space is not None:
        memory_space = Attribute.parse(str(memory_space))
    if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)):
        return _shaped(
            *shape,
            element_type=element_type,
            type_constructor=partial(UnrankedMemRefType.get, memory_space=memory_space),
        )
    return _shaped(
        *shape,
        element_type=element_type,
        type_constructor=partial(
            MemRefType.get, memory_space=memory_space, layout=layout
        ),
    )


def tuple(*elements):
    """Return `TupleType.get_tuple(elements)` for the given element types."""
    return TupleType.get_tuple(elements)


def function(*, inputs, results):
    """Return `FunctionType.get(inputs, results)`; both are keyword-only."""
    return FunctionType.get(inputs, results)