#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Callable, Dict, List, Sequence, Tuple, Union

from .....ir import *

from .... import func
from .... import linalg
from .... import math
from .... import arith
from .... import complex
from ...._ods_common import (
    get_op_result_or_value as _get_op_result_or_value,
    get_op_results_or_values as _get_op_results_or_values,
)

from .scalar_expr import *
from .config import *
from .comprehension import *
import numpy as np

__all__ = [
    "emit_generic_structured_op",
    "emit_named_structured_op",
    "ValueList",
]
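
# A minimal usage sketch (hedged; `config`, `lhs`, `rhs`, and `init` are
# hypothetical names). Given a LinalgStructuredOpConfig and tensor Values
# created under an active Context and InsertionPoint, the generic form of an
# op can be emitted as:
#
#   results = emit_generic_structured_op(config, lhs, rhs, outs=[init])
#
# In practice these entry points are invoked by the OpDSL-generated op
# builders rather than called directly.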

# Type aliases.
ValueList = Union[Sequence[Value], OpResultList]


def prepare_common_structured_op(
    op_config: LinalgStructuredOpConfig,
    *ins: Value,
    outs: ValueList,
    **attrs: Union[Sequence[int], TypeFnType],
):
    all_arg_defs = op_config.ordered_operands
    in_arg_defs = [
        d
        for d in all_arg_defs
        if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR]
    ]
    out_arg_defs = [d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR]
    index_attr_arg_defs = [d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR]
    fn_attr_arg_defs = [
        d
        for d in all_arg_defs
        if d.kind
        in [
            OperandKind.UNARY_FN_ATTR,
            OperandKind.BINARY_FN_ATTR,
            OperandKind.TERNARY_FN_ATTR,
            OperandKind.TYPE_FN_ATTR,
        ]
    ]

    # Verify outs is a sequence or a list of results.
    if not isinstance(outs, (Sequence, OpResultList)):
        raise ValueError(
            f"Expected named argument outs to have type Sequence or "
            f"OpResultList but got {type(outs)}"
        )

    # Arity validation.
    if len(ins) != len(in_arg_defs):
        raise ValueError(
            f"Expected {len(in_arg_defs)} inputs but got {len(ins)} for {op_config}"
        )
    if outs and len(outs) != len(out_arg_defs):
        raise ValueError(
            f"Expected {len(out_arg_defs)} outputs but got "
            f"{len(outs)} for {op_config}"
        )

    # Compute a replacement list for all index attribute symbols.
    expressions = []  # type: List[AffineExpr]
    replacements = []  # type: List[AffineExpr]
    for index_attr in index_attr_arg_defs:
        index_attr_vals = index_attr.operand_def.default_indices
        if index_attr.name in attrs:
            index_attr_vals = attrs.get(index_attr.name)
        assert index_attr_vals, "Index attribute has no value"
        if not all(isinstance(value, int) for value in index_attr_vals):
            raise ValueError(
                f"Attribute {index_attr.name} needs to be of type "
                f"Sequence[int] but got {index_attr_vals}"
            )
        results = index_attr.index_attr_map.results  # type: AffineExprList
        if len(index_attr_vals) != len(results):
            raise ValueError(
                f"Attribute {index_attr.name} needs {len(results)} values "
                f"but got {len(index_attr_vals)}"
            )
        for expr, value in zip(results, index_attr_vals):
            expressions.append(expr)
            replacements.append(AffineConstantExpr.get(value))

    # Replace all index attribute symbols by their value.
    # TODO: Add support for shape symbols.
    indexing_maps = []  # type: List[AffineMap]
    for curr in op_config.indexing_maps:
        for expression, replacement in zip(expressions, replacements):
            curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols)
        indexing_maps.append(curr)

    # TODO: Linalg verification does not currently allow symbols.
    # Compress them for now and verify none are left.
    indexing_maps = AffineMap.compress_unused_symbols(indexing_maps, Context.current)
    if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps):
        raise ValueError(
            f"Expected indexing_maps to use no symbols after "
            f"replacement and compression but got {indexing_maps}"
        )
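    # Hedged illustration with hypothetical values: for an op carrying an
    # index attribute `strides = [2]` bound to symbol s0, a configured map
    # such as
    #   (d0, d1)[s0] -> (d0 * s0 + d1)
    # has s0 replaced by the constant 2 and is then compressed to the
    # symbol-free map (d0, d1) -> (d0 * 2 + d1).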

    outs, out_types = _infer_structured_outs(
        op_config, in_arg_defs, ins, out_arg_defs, outs
    )

    result_types = [t for t in out_types if isinstance(t, RankedTensorType)]

    # Initialize the type dictionary with the predefined types.
    type_mapping = dict()  # type: Dict[str, Type]
    type_mapping["F32"] = F32Type.get()
    type_mapping["F64"] = F64Type.get()
    type_mapping["I32"] = IntegerType.get_signless(32)
    type_mapping["I64"] = IntegerType.get_signless(64)

    # Extract type vars for input/output based types.
    block_arg_types = list()  # type: List[Type]
    for arg_def, arg_element_type in zip(
        in_arg_defs + out_arg_defs, _get_types_from_values(*ins, *outs)
    ):
        _add_type_mapping(arg_def, arg_element_type, type_mapping, block_arg_types)

    # Emit the generic op.
    # TODO: Support emission of pure memref form.
    indexing_maps_attr = ArrayAttr.get([AffineMapAttr.get(am) for am in indexing_maps])
    iterator_types_attr = ArrayAttr.get(
        [
            Attribute.parse(f"#linalg.iterator_type<{s}>")
            for s in op_config.iterator_types
        ]
    )
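    # For illustration: iterator type names such as "parallel" or "reduction"
    # parse to attributes of the form #linalg.iterator_type<parallel>.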

    # Compute the index attributes used when emitting a named structured op.
    index_attrs = {}  # type: Dict[str, DenseElementsAttr]
    for index_attr in index_attr_arg_defs:
        index_attr_vals = attrs.get(index_attr.name)
        # Only forward attributes set to a non-default value.
        if index_attr_vals:
            array = np.array(index_attr_vals, dtype=np.int64)
            index_attrs[index_attr.name] = DenseElementsAttr.get(array)

    # Compute the function attribute mapping.
    fn_attr_mapping = {}
    for fn_attr in fn_attr_arg_defs:
        attr_val = fn_attr.operand_def.default_fn
        attr_kind = fn_attr.kind
        if fn_attr.name in attrs:
            fn = attrs.get(fn_attr.name)
            if attr_kind == OperandKind.UNARY_FN_ATTR:
                if not isinstance(fn, UnaryFnType):
                    raise ValueError(
                        f"Attribute {fn_attr.name} needs to be of type "
                        f"UnaryFnType but got {type(fn)}"
                    )
            elif attr_kind == OperandKind.BINARY_FN_ATTR:
                if not isinstance(fn, BinaryFnType):
                    raise ValueError(
                        f"Attribute {fn_attr.name} needs to be of type "
                        f"BinaryFnType but got {type(fn)}"
                    )
            elif attr_kind == OperandKind.TERNARY_FN_ATTR:
                if not isinstance(fn, TernaryFnType):
                    raise ValueError(
                        f"Attribute {fn_attr.name} needs to be of type "
                        f"TernaryFnType but got {type(fn)}"
                    )
            else:
                if not isinstance(fn, TypeFnType):
                    raise ValueError(
                        f"Attribute {fn_attr.name} needs to be of type "
                        f"TypeFnType but got {type(fn)}"
                    )
            attr_val = fn.fn_name
        assert attr_val, "Function attribute has no value"
        fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind)
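    # Hedged illustration (hypothetical values): a type function attribute
    # named "cast" overridden with TypeFn.cast_unsigned is recorded as
    #   fn_attr_mapping["cast"] = ("cast_unsigned", OperandKind.TYPE_FN_ATTR)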

    return (
        all_arg_defs,
        in_arg_defs,
        out_arg_defs,
        outs,
        result_types,
        type_mapping,
        indexing_maps_attr,
        iterator_types_attr,
        index_attrs,
        fn_attr_mapping,
        block_arg_types,
    )


def emit_generic_structured_op(
    op_config: LinalgStructuredOpConfig,
    *ins: Value,
    outs: ValueList,
    **attrs: Sequence[int],
):
    (
        all_arg_defs,
        in_arg_defs,
        out_arg_defs,
        outs,
        result_types,
        type_mapping,
        indexing_maps_attr,
        iterator_types_attr,
        index_attrs,
        fn_attr_mapping,
        block_arg_types,
    ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs)

    # An operation that accesses only scalars and scalar/rank zero tensors is
    # rank polymorphic; equivalently, an operation is rank polymorphic if its
    # iteration domain has rank zero. We implement rank polymorphism by
    # generating indexing maps and iterators that match the rank of the first
    # output tensor.
    if not iterator_types_attr:
        rank = ShapedType(outs[0].type).rank
        iterator_types_attr = ArrayAttr.get(
            [Attribute.parse("#linalg.iterator_type<parallel>")] * rank
        )
        scalar_map = AffineMap.get(rank, 0, [])
        tensor_map = AffineMap.get_identity(rank)
        indexing_maps = []
        for arg_def in all_arg_defs:
            if arg_def.operand_def.kind == OperandKind.SCALAR:
                indexing_maps.append(scalar_map)
            if arg_def.operand_def.is_tensor():
                idx = arg_def.operand_def.registered_index
                if idx < len(ins) and ShapedType(ins[idx].type).rank == 0:
                    indexing_maps.append(scalar_map)
                else:
                    indexing_maps.append(tensor_map)
        indexing_maps_attr = ArrayAttr.get(
            [AffineMapAttr.get(am) for am in indexing_maps]
        )
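        # Hedged worked example: for a fill-like op with a scalar input and a
        # rank-2 output tensor, rank == 2, so the scalar input receives the
        # empty map (d0, d1) -> () and the output the identity map
        # (d0, d1) -> (d0, d1), with two parallel iterators.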

    generic_op = linalg.GenericOp(
        result_tensors=result_types,
        inputs=ins,
        outputs=outs,
        indexing_maps=indexing_maps_attr,
        iterator_types=iterator_types_attr,
        doc=None,  # TODO: Make optional.
        library_call=None,  # TODO: Make optional.
    )

    # Construct the body.
    block_arg_names = _get_operand_def_names(*in_arg_defs, *out_arg_defs)
    block = generic_op.regions[0].blocks.append(*block_arg_types)
    block_arg_mapping = dict(zip(block_arg_names, block.arguments))
    with InsertionPoint(block):
        body_builder = _BodyBuilder(type_mapping, block_arg_mapping, fn_attr_mapping)
        for assignment in op_config.assignments:
            body_builder.assign(assignment)
        body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs))

    if len(result_types) == 1:
        return generic_op.result
    else:
        return generic_op.results


def emit_named_structured_op(
    op_config: LinalgStructuredOpConfig,
    op_name: str,
    op_class_name: str,
    *ins: Value,
    outs: ValueList,
    **attrs: Sequence[int],
):
    (
        all_arg_defs,
        in_arg_defs,
        out_arg_defs,
        outs,
        result_types,
        type_mapping,
        indexing_maps_attr,
        iterator_types_attr,
        index_attrs,
        fn_attr_mapping,
        block_arg_types,
    ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs)

    # If we get here, there must exist a builtin class `op_class_name`.
    ctx = Context.current
    fully_qualified_name = "linalg." + op_name
    if (
        not ctx.is_registered_operation(fully_qualified_name)
        or op_class_name not in linalg.__dict__
    ):
        raise NotImplementedError(
            f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}"
        )

    # Set the index attributes used to compute the indexing maps.
    named_op = getattr(linalg, op_class_name)(result_types, ins, outs)
    for name, value in index_attrs.items():
        named_op.operation.attributes[name] = value

    # Compute the function attributes by combining operand kind and function name.
    for name, (fn_name, kind) in fn_attr_mapping.items():
        assert kind.name.lower().endswith("_attr")
        enum_name = kind.name.lower()[:-5]
        named_op.operation.attributes[name] = Attribute.parse(
            f"#linalg.{enum_name}<{fn_name}>"
        )
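    # For example, kind OperandKind.TYPE_FN_ATTR yields enum_name "type_fn",
    # so a function named "cast_signed" parses to #linalg.type_fn<cast_signed>.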

    linalg.fill_builtin_region(named_op.operation)

    if len(result_types) == 1:
        return named_op.result
    else:
        return named_op.results


class _BodyBuilder:
    """Constructs a structured op body by evaluating assignments."""

    def __init__(
        self,
        type_mapping: Dict[str, Type],
        block_arg_mapping: Dict[str, Value],
        fn_attr_mapping: Dict[str, Tuple[str, OperandKind]],
    ):
        self.type_mapping = type_mapping
        self.block_arg_mapping = block_arg_mapping
        self.fn_attr_mapping = fn_attr_mapping
        self.yield_mapping = dict()  # type: Dict[str, Value]

    def assign(self, assignment: ScalarAssign):
        if assignment.arg in self.yield_mapping:
            raise ValueError(
                f"Multiple assignments to the same argument are forbidden: "
                f"{assignment}"
            )
        self.yield_mapping[assignment.arg] = self.expression(assignment.value)

    def expression(self, expr: ScalarExpression) -> Value:
        if expr.scalar_arg:
            try:
                return self.block_arg_mapping[expr.scalar_arg.arg]
            except KeyError:
                raise ValueError(
                    f"Argument {expr.scalar_arg.arg} is not bound for "
                    f"this structured op."
                )
        elif expr.scalar_const:
            value_attr = Attribute.parse(expr.scalar_const.value)
            return arith.ConstantOp(value_attr.type, value_attr).result
        elif expr.scalar_index:
            dim_attr = IntegerAttr.get(
                IntegerType.get_signless(64), expr.scalar_index.dim
            )
            return linalg.IndexOp(dim_attr).result
        elif expr.scalar_fn:
            kind = expr.scalar_fn.kind.name.lower()
            fn_name = expr.scalar_fn.fn_name
            if expr.scalar_fn.attr_name:
                fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name]
            fn = self._get_function(f"_{kind}_{fn_name}")
            operand_values = [
                self.expression(operand) for operand in expr.scalar_fn.operands
            ]
            if expr.scalar_fn.kind == FunctionKind.TYPE:
                operand_values = [expr.scalar_fn.type_var.name] + operand_values
            return fn(*operand_values)
        raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
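
    # Dispatch illustration: a BINARY-kind "mul" function resolves to the
    # _binary_mul method below, while a TYPE-kind function such as
    # "cast_signed" resolves to _type_cast_signed and receives the name of its
    # type variable (e.g. "T") prepended to the operand list.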

    def yield_outputs(self, *output_names: str):
        output_values = []
        for n in output_names:
            try:
                output_values.append(self.yield_mapping[n])
            except KeyError:
                raise ValueError(
                    f"Body assignments do not assign all outputs: missing '{n}'"
                )
        linalg.YieldOp(output_values)

    def _get_function(self, fn_name: str) -> Callable:
        try:
            fn = getattr(self, fn_name)
        except AttributeError:
            raise ValueError(f"Function '{fn_name}' is not a known function")
        return fn

    def _cast(
        self, type_var_name: str, operand: Value, is_unsigned_cast: bool = False
    ) -> Value:
        try:
            to_type = self.type_mapping[type_var_name]
        except KeyError:
            raise ValueError(
                f"Unbound type variable '{type_var_name}' "
                f"(expected one of {self.type_mapping.keys()})"
            )
        if operand.type == to_type:
            return operand
        if _is_integer_type(to_type):
            return self._cast_to_integer(to_type, operand, is_unsigned_cast)
        elif _is_floating_point_type(to_type):
            return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
        raise ValueError(
            f"Unable to cast body expression from {operand.type} to {to_type}"
        )

    def _cast_to_integer(
        self, to_type: Type, operand: Value, is_unsigned_cast: bool
    ) -> Value:
        to_width = IntegerType(to_type).width
        operand_type = operand.type
        if _is_floating_point_type(operand_type):
            if is_unsigned_cast:
                return arith.FPToUIOp(to_type, operand).result
            return arith.FPToSIOp(to_type, operand).result
        if _is_index_type(operand_type):
            return arith.IndexCastOp(to_type, operand).result
        # Assume integer.
        from_width = IntegerType(operand_type).width
        if to_width > from_width:
            if is_unsigned_cast:
                return arith.ExtUIOp(to_type, operand).result
            return arith.ExtSIOp(to_type, operand).result
        elif to_width < from_width:
            return arith.TruncIOp(to_type, operand).result
        raise ValueError(
            f"Unable to cast body expression from {operand_type} to {to_type}"
        )

    def _cast_to_floating_point(
        self, to_type: Type, operand: Value, is_unsigned_cast: bool
    ) -> Value:
        operand_type = operand.type
        if _is_integer_type(operand_type):
            if is_unsigned_cast:
                return arith.UIToFPOp(to_type, operand).result
            return arith.SIToFPOp(to_type, operand).result
        # Assume FloatType.
        to_width = _get_floating_point_width(to_type)
        from_width = _get_floating_point_width(operand_type)
        if to_width > from_width:
            return arith.ExtFOp(to_type, operand).result
        elif to_width < from_width:
            return arith.TruncFOp(to_type, operand).result
        raise ValueError(
            f"Unable to cast body expression from {operand_type} to {to_type}"
        )

    def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value:
        return self._cast(type_var_name, operand, False)

    def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
        return self._cast(type_var_name, operand, True)

    def _unary_exp(self, x: Value) -> Value:
        if _is_floating_point_type(x.type):
            return math.ExpOp(x).result
        raise NotImplementedError(f"Unsupported 'exp' operand: {x}")

    def _unary_log(self, x: Value) -> Value:
        if _is_floating_point_type(x.type):
            return math.LogOp(x).result
        raise NotImplementedError(f"Unsupported 'log' operand: {x}")

    def _unary_abs(self, x: Value) -> Value:
        if _is_floating_point_type(x.type):
            return math.AbsFOp(x).result
        raise NotImplementedError(f"Unsupported 'abs' operand: {x}")

    def _unary_ceil(self, x: Value) -> Value:
        if _is_floating_point_type(x.type):
            return math.CeilOp(x).result
        raise NotImplementedError(f"Unsupported 'ceil' operand: {x}")

    def _unary_floor(self, x: Value) -> Value:
        if _is_floating_point_type(x.type):
            return math.FloorOp(x).result
        raise NotImplementedError(f"Unsupported 'floor' operand: {x}")

    def _unary_negf(self, x: Value) -> Value:
        if _is_floating_point_type(x.type):
            return arith.NegFOp(x).result
        if _is_complex_type(x.type):
            return complex.NegOp(x).result
        raise NotImplementedError(f"Unsupported 'negf' operand: {x}")

    def _binary_add(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.AddFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.AddIOp(lhs, rhs).result
        if _is_complex_type(lhs.type):
            return complex.AddOp(lhs, rhs).result
        raise NotImplementedError(f"Unsupported 'add' operands: {lhs}, {rhs}")

    def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.SubFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.SubIOp(lhs, rhs).result
        if _is_complex_type(lhs.type):
            return complex.SubOp(lhs, rhs).result
        raise NotImplementedError(f"Unsupported 'sub' operands: {lhs}, {rhs}")

    def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.MulFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.MulIOp(lhs, rhs).result
        if _is_complex_type(lhs.type):
            return complex.MulOp(lhs, rhs).result
        raise NotImplementedError(f"Unsupported 'mul' operands: {lhs}, {rhs}")

    def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.MaximumFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.MaxSIOp(lhs, rhs).result
        raise NotImplementedError(f"Unsupported 'max' operands: {lhs}, {rhs}")

    def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.MaximumFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.MaxUIOp(lhs, rhs).result
        raise NotImplementedError(f"Unsupported 'max_unsigned' operands: {lhs}, {rhs}")

    def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.MinimumFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.MinSIOp(lhs, rhs).result
        raise NotImplementedError(f"Unsupported 'min' operands: {lhs}, {rhs}")

    def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.MinimumFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.MinUIOp(lhs, rhs).result
        raise NotImplementedError(f"Unsupported 'min_unsigned' operands: {lhs}, {rhs}")


def _infer_structured_outs(
    op_config: LinalgStructuredOpConfig,
    in_arg_defs: Sequence[OperandDefConfig],
    ins: Sequence[Value],
    out_arg_defs: Sequence[OperandDefConfig],
    outs: Union[Sequence[Value], OpResultList],
) -> Tuple[ValueList, List[Type]]:
    """Infers implicit outs and output types.

    Respects existing contents of outs if not empty.

    Returns:
      normalized outs, output types
    """
    # If outs were explicitly provided, we accept them verbatim.
    if outs:
        return outs, [out.type for out in outs]

    raise NotImplementedError(
        "Output tensor inference not yet supported for structured ops"
    )


def _get_types_from_values(*values: Value) -> Sequence[Type]:
    return [v.type for v in values]


def _get_operand_def_names(*operand_configs: OperandDefConfig) -> Sequence[str]:
    return [odc.operand_def.name for odc in operand_configs]


def _add_type_mapping(
    operand_config: OperandDefConfig,
    operand_type: Type,
    type_mapping: Dict[str, Type],
    block_arg_types: List[Type],
):
    element_or_self_type = operand_type
    # Get the element type for tensor operands and the type itself for scalars.
    if operand_config.shape_map:
        try:
            element_or_self_type = ShapedType(operand_type).element_type
        except Exception as e:
            raise ValueError(f"Expected ShapedType but got {operand_type}") from e
    name = operand_config.type_var.name
    if name in type_mapping:
        if type_mapping[name] != element_or_self_type:
            raise ValueError(
                f"Cannot overwrite type mapping {name} = "
                f"{type_mapping[name]} by type {element_or_self_type}"
            )
    type_mapping[name] = element_or_self_type
    block_arg_types.append(element_or_self_type)
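    # For example, a tensor<4x8xf32> operand whose type variable is T adds the
    # mapping T -> f32 and records f32 as the corresponding block argument type.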


def _is_complex_type(t: Type) -> bool:
    return ComplexType.isinstance(t)


def _is_floating_point_type(t: Type) -> bool:
    # TODO: Create a FloatType in the Python API and implement the switch
    # there.
    return (
        F64Type.isinstance(t)
        or F32Type.isinstance(t)
        or F16Type.isinstance(t)
        or BF16Type.isinstance(t)
    )


def _is_integer_type(t: Type) -> bool:
    return IntegerType.isinstance(t)


def _is_index_type(t: Type) -> bool:
    return IndexType.isinstance(t)


def _get_floating_point_width(t: Type) -> int:
    # TODO: Create a FloatType in the Python API and implement the switch
    # there.
    if F64Type.isinstance(t):
        return 64
    if F32Type.isinstance(t):
        return 32
    if F16Type.isinstance(t):
        return 16
    if BF16Type.isinstance(t):
        return 16
    raise NotImplementedError(f"Unhandled floating point type switch {t}")