xref: /llvm-project/mlir/python/mlir/dialects/arith.py (revision 5d59fa90ce225814739d9b51ba37e1cca9204cad)
1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from ._arith_ops_gen import *
6from ._arith_ops_gen import _Dialect
7from ._arith_enum_gen import *
8from array import array as _array
9from typing import overload
10
11try:
12    from ..ir import *
13    from ._ods_common import (
14        get_default_loc_context as _get_default_loc_context,
15        _cext as _ods_cext,
16        get_op_result_or_op_results as _get_op_result_or_op_results,
17    )
18
19    from typing import Any, List, Union
20except ImportError as e:
21    raise RuntimeError("Error loading imports from extension module") from e
22
23
24def _isa(obj: Any, cls: type):
25    try:
26        cls(obj)
27    except ValueError:
28        return False
29    return True
30
31
def _is_any_of(obj: Any, classes: List[type]):
    """Return True if ``obj`` passes the ``_isa`` check for at least one class.

    Classes are tried in order and the search stops at the first match. An
    empty ``classes`` list yields ``False``.
    """
    for candidate in classes:
        if _isa(obj, candidate):
            return True
    return False
34
35
def _is_integer_like_type(type: Type):
    """Return True if ``type`` checks out as an ``IntegerType`` or ``IndexType``."""
    integer_like_classes = [IntegerType, IndexType]
    return _is_any_of(type, integer_like_classes)
38
39
def _is_float_type(type: Type):
    """Return True if ``type`` checks out as one of BF16, F16, F32 or F64."""
    float_classes = [BF16Type, F16Type, F32Type, F64Type]
    return _is_any_of(type, float_classes)
42
43
@_ods_cext.register_operation(_Dialect, replace=True)
class ConstantOp(ConstantOp):
    """Specialization for the constant op class.

    Two call forms are supported:

    * ``ConstantOp(attr)``: ``attr`` is an already-built ``Attribute`` that is
      used directly as the op's value.
    * ``ConstantOp(result, value)``: the value attribute is built from a
      Python ``int``, ``float`` or ``array.array`` using ``result`` as its
      type. Any other ``value`` is forwarded unchanged to the generated
      constructor.
    """

    @overload
    def __init__(self, value: Attribute, *, loc=None, ip=None):
        ...

    @overload
    def __init__(
        self, result: Type, value: Union[int, float, _array], *, loc=None, ip=None
    ):
        ...

    def __init__(self, result, value=None, *, loc=None, ip=None):
        """Build the constant op.

        Args:
            result: The result type, or the value attribute itself when
                ``value`` is omitted (single-argument form).
            value: An ``int``, ``float``, ``array.array`` or attribute. It
                defaults to ``None`` so that the single-argument overload
                declared above is actually callable.
            loc: Optional location, forwarded to the generated constructor.
            ip: Optional insertion point, forwarded to the generated
                constructor.

        Raises:
            ValueError: If an array's item width does not match the width of
                ``result.element_type``, or its typecode is not one of
                ``i``, ``l``, ``q``, ``f``, ``d``.
        """
        if value is None:
            # Single-argument form: `result` is really the value attribute.
            assert isinstance(
                result, Attribute
            ), "single-argument ConstantOp expects an Attribute"
            attr = result
        elif isinstance(value, int):
            attr = IntegerAttr.get(result, value)
        elif isinstance(value, float):
            attr = FloatAttr.get(result, value)
        elif isinstance(value, _array):
            # `itemsize` is in bytes; the element type reports its width in bits.
            if 8 * value.itemsize != result.element_type.width:
                raise ValueError(
                    f"Mismatching array element ({8 * value.itemsize}) and type ({result.element_type.width}) width."
                )
            if value.typecode in ("i", "l", "q"):
                attr = DenseIntElementsAttr.get(value, type=result)
            elif value.typecode in ("f", "d"):
                attr = DenseFPElementsAttr.get(value, type=result)
            else:
                raise ValueError(f'Unsupported typecode: "{value.typecode}".')
        else:
            # Anything else is passed through untouched.
            attr = value

        # One construction point so that `loc` and `ip` are honoured on every
        # path, including the dense-array ones.
        super().__init__(attr, loc=loc, ip=ip)

    @classmethod
    def create_index(cls, value: int, *, loc=None, ip=None):
        """Create an index-typed constant."""
        return cls(
            IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip
        )

    @property
    def type(self):
        """Type of the op's first (only accessed) result."""
        return self.results[0].type

    @property
    def value(self):
        """The op's ``"value"`` attribute wrapped as a generic ``Attribute``."""
        return Attribute(self.operation.attributes["value"])

    @property
    def literal_value(self) -> Union[int, float]:
        """Python ``int``/``float`` for integer-like or float constants.

        Raises:
            ValueError: If the result type is neither integer-like nor one of
                the float types recognised by ``_is_float_type``.
        """
        if _is_integer_like_type(self.type):
            return IntegerAttr(self.value).value
        elif _is_float_type(self.type):
            return FloatAttr(self.value).value
        else:
            raise ValueError("only integer and float constants have literal values")
105
106
def constant(
    result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
) -> Value:
    """Create a ``ConstantOp`` and return it through ``_get_op_result_or_op_results``."""
    op = ConstantOp(result, value, loc=loc, ip=ip)
    return _get_op_result_or_op_results(op)
111