# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._arith_ops_gen import *
from ._arith_ops_gen import _Dialect
from ._arith_enum_gen import *
from array import array as _array
from typing import overload

try:
    from ..ir import *
    from ._ods_common import (
        get_default_loc_context as _get_default_loc_context,
        _cext as _ods_cext,
        get_op_result_or_op_results as _get_op_result_or_op_results,
    )

    from typing import Any, List, Union
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e


def _isa(obj: Any, cls: type) -> bool:
    """Return True if ``obj`` can be cast to ``cls``.

    The check works by attempting the cast ``cls(obj)``; the casting
    constructor is expected to raise ``ValueError`` when ``obj`` is not of the
    requested kind.
    """
    try:
        cls(obj)
    except ValueError:
        return False
    return True


def _is_any_of(obj: Any, classes: List[type]) -> bool:
    """Return True if ``obj`` can be cast to at least one of ``classes``."""
    return any(_isa(obj, cls) for cls in classes)


def _is_integer_like_type(type: Type) -> bool:
    """Return True for ``IntegerType`` and ``IndexType``."""
    return _is_any_of(type, [IntegerType, IndexType])


def _is_float_type(type: Type) -> bool:
    """Return True for the BF16, F16, F32 and F64 float types only."""
    return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])


@_ods_cext.register_operation(_Dialect, replace=True)
class ConstantOp(ConstantOp):
    """Specialization for the constant op class.

    Extends the generated ``arith.constant`` builder so the ``value`` attribute
    can be built from a Python ``int``, ``float`` or ``array.array`` together
    with a result type, in addition to passing a ready-made ``Attribute``.
    """

    @overload
    def __init__(self, value: Attribute, *, loc=None, ip=None):
        ...

    @overload
    def __init__(
        self, result: Type, value: Union[int, float, _array], *, loc=None, ip=None
    ):
        ...

    def __init__(self, result, value=None, *, loc=None, ip=None):
        """Build an ``arith.constant`` operation.

        Args:
          result: In the single-argument form (``value`` omitted or ``None``),
            the ``Attribute`` holding the constant. Otherwise the result type.
          value: ``int`` is wrapped in ``IntegerAttr.get(result, value)``;
            ``float`` is wrapped in ``FloatAttr.get(result, value)``; an
            ``array.array`` becomes a ``DenseIntElementsAttr`` (typecodes
            ``i``/``l``/``q``) or ``DenseFPElementsAttr`` (typecodes ``f``/``d``)
            of type ``result``. Any other value is forwarded unchanged to the
            generated builder.
          loc: Location forwarded to the generated builder.
          ip: Insertion point forwarded to the generated builder.

        Raises:
          TypeError: the single-argument form was used with a non-``Attribute``.
          ValueError: the array item width does not match the width of
            ``result.element_type``, or the array typecode is unsupported.
        """
        # Single-argument form: ``result`` already is the constant attribute.
        if value is None:
            if not isinstance(result, Attribute):
                raise TypeError(
                    f"Expected an Attribute when no value is given, got {result!r}."
                )
            super().__init__(result, loc=loc, ip=ip)
            return

        if isinstance(value, int):
            super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
        elif isinstance(value, float):
            super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
        elif isinstance(value, _array):
            # The raw buffer is reinterpreted by the dense attribute, so the
            # per-item size (in bits) must equal the element type width.
            item_bits = 8 * value.itemsize
            element_bits = result.element_type.width
            if item_bits != element_bits:
                raise ValueError(
                    f"Mismatching array element ({item_bits}) and type ({element_bits}) width."
                )
            if value.typecode in ("i", "l", "q"):
                attr = DenseIntElementsAttr.get(value, type=result)
            elif value.typecode in ("f", "d"):
                attr = DenseFPElementsAttr.get(value, type=result)
            else:
                raise ValueError(f'Unsupported typecode: "{value.typecode}".')
            # Forward loc/ip like the scalar branches so array constants honor
            # an explicit location and insertion point.
            super().__init__(attr, loc=loc, ip=ip)
        else:
            super().__init__(value, loc=loc, ip=ip)

    @classmethod
    def create_index(cls, value: int, *, loc=None, ip=None):
        """Create an index-typed constant.

        The ``IndexType`` is created in the context derived from ``loc`` via
        ``get_default_loc_context``.
        """
        return cls(
            IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip
        )

    @property
    def type(self):
        """The type of the op's first (only) result."""
        return self.results[0].type

    @property
    def value(self):
        """The op's ``value`` attribute as a generic ``Attribute``."""
        return Attribute(self.operation.attributes["value"])

    @property
    def literal_value(self) -> Union[int, float]:
        """The constant as a Python ``int`` or ``float``.

        Raises:
          ValueError: the result type is neither integer-like (integer/index)
            nor one of the BF16/F16/F32/F64 float types.
        """
        if _is_integer_like_type(self.type):
            return IntegerAttr(self.value).value
        elif _is_float_type(self.type):
            return FloatAttr(self.value).value
        else:
            raise ValueError("only integer and float constants have literal values")


def constant(
    result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
) -> Value:
    """Create an ``arith.constant`` op and return its result value.

    ``result`` and ``value`` are interpreted exactly as by the ``ConstantOp``
    constructor; ``loc`` and ``ip`` are forwarded to it.
    """
    return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))