# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Callable, Dict, List, Sequence, Tuple, Union

from .....ir import *

from .... import func
from .... import linalg
from .... import math
from .... import arith
from .... import complex
from ...._ods_common import (
    get_op_result_or_value as _get_op_result_or_value,
    get_op_results_or_values as _get_op_results_or_values,
)

from .scalar_expr import *
from .config import *
from .comprehension import *
import numpy as np

__all__ = [
    "emit_generic_structured_op",
    "emit_named_structured_op",
    "ValueList",
]

# Type aliases.
ValueList = Union[Sequence[Value], OpResultList]


def prepare_common_structured_op(
    op_config: LinalgStructuredOpConfig,
    *ins: Value,
    outs: ValueList,
    **attrs: Union[Sequence[int], TypeFnType],
):
    """Validates operands/attributes and computes the pieces needed to emit an op.

    Shared by `emit_generic_structured_op` and `emit_named_structured_op`.

    Args:
      op_config: The structured op configuration to instantiate.
      *ins: Input values; their count must equal the number of SCALAR and
        INPUT_TENSOR operands in `op_config`.
      outs: Output values, either a sequence or an `OpResultList`. If non-empty
        its length must equal the number of OUTPUT_TENSOR operands.
      **attrs: Per-operand overrides keyed by operand name: sequences of ints
        for index attributes, and `UnaryFnType` / `BinaryFnType` /
        `TernaryFnType` / `TypeFnType` instances for function attributes.

    Returns:
      A tuple `(all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
      type_mapping, indexing_maps_attr, iterator_types_attr, index_attrs,
      fn_attr_mapping, block_arg_types)`.

    Raises:
      ValueError: On wrong operand arity, malformed or missing attribute
        values, inconsistent operand element types, or indexing maps that still
        use symbols after index attribute substitution.
      NotImplementedError: If `outs` is empty (output inference is not
        supported).
    """
    all_arg_defs = op_config.ordered_operands
    in_arg_defs = [
        d
        for d in all_arg_defs
        if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR]
    ]
    out_arg_defs = [d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR]
    index_attr_arg_defs = [d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR]
    fn_attr_arg_defs = [
        d
        for d in all_arg_defs
        if d.kind
        in [
            OperandKind.UNARY_FN_ATTR,
            OperandKind.BINARY_FN_ATTR,
            OperandKind.TERNARY_FN_ATTR,
            OperandKind.TYPE_FN_ATTR,
        ]
    ]

    # Verify outs is a sequence or a list of results.
    if not isinstance(outs, (Sequence, OpResultList)):
        raise ValueError(
            f"Expected named argument outs to have type Sequence or "
            f"OpResultList but got {type(outs)}"
        )

    # Arity validation.
    if len(ins) != len(in_arg_defs):
        raise ValueError(
            f"Expected {len(in_arg_defs)} inputs but got {len(ins)} for {op_config}"
        )
    if outs and len(outs) != len(out_arg_defs):
        raise ValueError(
            f"Expected {len(out_arg_defs)} outputs but got "
            f"{len(outs)} for {op_config}"
        )

    # Compute a replacement list for all index attribute symbols.
    expressions = []  # type: List[AffineExpr]
    replacements = []  # type: List[AffineExpr]
    for index_attr in index_attr_arg_defs:
        # Use the operand's declared default indices unless the caller
        # explicitly passed a value for this attribute.
        index_attr_vals = index_attr.operand_def.default_indices
        if index_attr.name in attrs:
            index_attr_vals = attrs.get(index_attr.name)
        if not index_attr_vals:
            raise ValueError(f"Index attribute {index_attr.name} has no value")
        if not all(isinstance(value, int) for value in index_attr_vals):
            raise ValueError(
                f"Attribute {index_attr.name} needs to be of type "
                f"Sequence[int] but got {type(index_attr_vals)}"
            )
        results = index_attr.index_attr_map.results  # type: AffineExprList
        if len(index_attr_vals) != len(results):
            raise ValueError(
                f"Attribute {index_attr.name} has length {len(results)} "
                f"but got {len(index_attr_vals)} values"
            )
        for expr, value in zip(results, index_attr_vals):
            expressions.append(expr)
            replacements.append(AffineConstantExpr.get(value))

    # Replace all index attribute symbols by their value.
    # TODO: Add support for shape symbols.
    indexing_maps = []  # type: List[AffineMap]
    for curr in op_config.indexing_maps:
        for expression, replacement in zip(expressions, replacements):
            curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols)
        indexing_maps.append(curr)

    # TODO: Linalg verification does not currently allow symbols.
    # Compress them for now and verify none are left.
    indexing_maps = AffineMap.compress_unused_symbols(indexing_maps, Context.current)
    if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps):
        raise ValueError(
            f"Expected indexing_maps to use no symbols after "
            f"replacement and compression but got {indexing_maps}"
        )

    outs, out_types = _infer_structured_outs(
        op_config, in_arg_defs, ins, out_arg_defs, outs
    )

    # Only RankedTensorType outputs become result types of the emitted op.
    result_types = [t for t in out_types if isinstance(t, RankedTensorType)]

    # Initialize the type dictionary with the predefined types.
    type_mapping = {
        "F32": F32Type.get(),
        "F64": F64Type.get(),
        "I32": IntegerType.get_signless(32),
        "I64": IntegerType.get_signless(64),
    }  # type: Dict[str, Type]

    # Extract type vars for input/output based types.
    block_arg_types = []  # type: List[Type]
    for arg_def, arg_element_type in zip(
        in_arg_defs + out_arg_defs, _get_types_from_values(*ins, *outs)
    ):
        _add_type_mapping(arg_def, arg_element_type, type_mapping, block_arg_types)

    # Emit the generic op.
    # TODO: Support emission of pure memref form.
    indexing_maps_attr = ArrayAttr.get([AffineMapAttr.get(am) for am in indexing_maps])
    iterator_types_attr = ArrayAttr.get(
        [
            Attribute.parse(f"#linalg.iterator_type<{s}>")
            for s in op_config.iterator_types
        ]
    )

    # Compute the index attributes used when emitting a named structured op.
    index_attrs = {}  # type: Dict[str, DenseElementsAttr]
    for index_attr in index_attr_arg_defs:
        index_attr_vals = attrs.get(index_attr.name)
        # Only forward attributes that the caller explicitly set (and that are
        # non-empty); defaults are left to the named op itself.
        if index_attr_vals:
            array = np.array(index_attr_vals, dtype=np.int64)
            index_attrs[index_attr.name] = DenseElementsAttr.get(array)

    # Python type an explicitly passed function attribute must have, keyed by
    # operand kind. These are exactly the kinds collected in fn_attr_arg_defs.
    fn_attr_types = {
        OperandKind.UNARY_FN_ATTR: UnaryFnType,
        OperandKind.BINARY_FN_ATTR: BinaryFnType,
        OperandKind.TERNARY_FN_ATTR: TernaryFnType,
        OperandKind.TYPE_FN_ATTR: TypeFnType,
    }

    # Compute the function attribute mapping: name -> (fn_name, operand kind).
    fn_attr_mapping = {}  # type: Dict[str, Tuple[str, OperandKind]]
    for fn_attr in fn_attr_arg_defs:
        attr_val = fn_attr.operand_def.default_fn
        attr_kind = fn_attr.kind
        if fn_attr.name in attrs:
            fn = attrs.get(fn_attr.name)
            expected_type = fn_attr_types[attr_kind]
            if not isinstance(fn, expected_type):
                # Report the type of the value the caller passed, not the
                # type of the default.
                raise ValueError(
                    f"Attribute {fn_attr.name} needs to be of type "
                    f"{expected_type.__name__} but got {type(fn)}"
                )
            attr_val = fn.fn_name
        if not attr_val:
            raise ValueError(f"Function attribute {fn_attr.name} has no value")
        fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind)

    return (
        all_arg_defs,
        in_arg_defs,
        out_arg_defs,
        outs,
        result_types,
        type_mapping,
        indexing_maps_attr,
        iterator_types_attr,
        index_attrs,
        fn_attr_mapping,
        block_arg_types,
    )


def emit_generic_structured_op(
    op_config: LinalgStructuredOpConfig,
    *ins: Value,
    outs: ValueList,
    **attrs: Sequence[int],
):
    """Emits a `linalg.generic` op implementing `op_config`.

    The op body is built by evaluating `op_config.assignments` with
    `_BodyBuilder` and yielding the output operands.

    Args:
      op_config: The structured op configuration to instantiate.
      *ins: Input values.
      outs: Output values.
      **attrs: Index/function attribute overrides, keyed by operand name.

    Returns:
      The op's single result if it has exactly one tensor result, otherwise
      the op's result list.
    """
    (
        all_arg_defs,
        in_arg_defs,
        out_arg_defs,
        outs,
        result_types,
        type_mapping,
        indexing_maps_attr,
        iterator_types_attr,
        index_attrs,
        fn_attr_mapping,
        block_arg_types,
    ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs)

    # An operation that accesses only scalars and scalar/rank zero tensors is
    # rank polymorphic. We implement rank polymorphism by generating different
    # indexing maps and iterators that match the rank of the first output tensor.
    # An operation is rank polymorphic if the iteration domain has rank zero.
    if not iterator_types_attr:
        rank = ShapedType(outs[0].type).rank
        iterator_types_attr = ArrayAttr.get(
            [Attribute.parse("#linalg.iterator_type<parallel>")] * rank
        )
        scalar_map = AffineMap.get(rank, 0, [])
        tensor_map = AffineMap.get_identity(rank)
        indexing_maps = []
        for arg_def in all_arg_defs:
            if arg_def.operand_def.kind == OperandKind.SCALAR:
                indexing_maps.append(scalar_map)
            if arg_def.operand_def.is_tensor():
                idx = arg_def.operand_def.registered_index
                # Rank-0 input tensors are broadcast like scalars.
                if idx < len(ins) and ShapedType(ins[idx].type).rank == 0:
                    indexing_maps.append(scalar_map)
                else:
                    indexing_maps.append(tensor_map)
        indexing_maps_attr = ArrayAttr.get(
            [AffineMapAttr.get(am) for am in indexing_maps]
        )

    generic_op = linalg.GenericOp(
        result_tensors=result_types,
        inputs=ins,
        outputs=outs,
        indexing_maps=indexing_maps_attr,
        iterator_types=iterator_types_attr,
        doc=None,  # TODO: Make optional.
        library_call=None,
    )  # TODO: Make optional.

    # Construct the body: one block argument per input and output operand.
    block_arg_names = _get_operand_def_names(*in_arg_defs, *out_arg_defs)
    block = generic_op.regions[0].blocks.append(*block_arg_types)
    block_arg_mapping = dict(zip(block_arg_names, block.arguments))
    with InsertionPoint(block):
        body_builder = _BodyBuilder(type_mapping, block_arg_mapping, fn_attr_mapping)
        for assignment in op_config.assignments:
            body_builder.assign(assignment)
        body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs))

    if len(result_types) == 1:
        return generic_op.result
    return generic_op.results


def emit_named_structured_op(
    op_config: LinalgStructuredOpConfig,
    op_name: str,
    op_class_name: str,
    *ins: Value,
    outs: ValueList,
    **attrs: Sequence[int],
):
    """Emits the named op `linalg.<op_name>` using the class `linalg.<op_class_name>`.

    Explicitly set index attributes and all function attributes are attached
    to the op, then its body is populated with `linalg.fill_builtin_region`.

    Args:
      op_config: The structured op configuration to instantiate.
      op_name: Op name without the `linalg.` prefix.
      op_class_name: Name of the op's Python class in the `linalg` module.
      *ins: Input values.
      outs: Output values.
      **attrs: Index/function attribute overrides, keyed by operand name.

    Returns:
      The op's single result if it has exactly one tensor result, otherwise
      the op's result list.

    Raises:
      NotImplementedError: If `linalg.<op_name>` is not registered in the
        current context or `linalg` has no attribute `op_class_name`.
    """
    (
        all_arg_defs,
        in_arg_defs,
        out_arg_defs,
        outs,
        result_types,
        type_mapping,
        indexing_maps_attr,
        iterator_types_attr,
        index_attrs,
        fn_attr_mapping,
        block_arg_types,
    ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs)

    # The op must be registered and a builtin class `op_class_name` must exist.
    ctx = Context.current
    fully_qualified_name = "linalg." + op_name
    if (
        not ctx.is_registered_operation(fully_qualified_name)
        or op_class_name not in linalg.__dict__
    ):
        raise NotImplementedError(
            f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}"
        )

    # Set the index attributes used to compute the indexing maps.
    named_op = getattr(linalg, op_class_name)(result_types, ins, outs)
    for name, value in index_attrs.items():
        named_op.operation.attributes[name] = value

    # Compute the function attributes by combining operand kind and function
    # name, e.g. kind UNARY_FN_ATTR yields `#linalg.unary_fn<fn_name>`.
    for name, (fn_name, kind) in fn_attr_mapping.items():
        assert kind.name.lower().endswith("_attr")
        enum_name = kind.name.lower()[:-5]
        named_op.operation.attributes[name] = Attribute.parse(
            f"#linalg.{enum_name}<{fn_name}>"
        )

    linalg.fill_builtin_region(named_op.operation)

    if len(result_types) == 1:
        return named_op.result
    return named_op.results


class _BodyBuilder:
    """Constructs a structured op body by evaluating assignments.

    Ops are created at the current insertion point.
    """

    def __init__(
        self,
        type_mapping: Dict[str, Type],
        block_arg_mapping: Dict[str, Value],
        fn_attr_mapping: Dict[str, Tuple[str, OperandKind]],
    ):
        # Type variable name -> concrete MLIR type.
        self.type_mapping = type_mapping
        # Operand name -> block argument value.
        self.block_arg_mapping = block_arg_mapping
        # Function attribute name -> (function name, operand kind).
        self.fn_attr_mapping = fn_attr_mapping
        # Assigned output name -> computed value, filled by `assign`.
        self.yield_mapping = dict()  # type: Dict[str, Value]

    def assign(self, assignment: ScalarAssign):
        """Evaluates `assignment.value` and records it for `assignment.arg`.

        Raises:
          ValueError: If `assignment.arg` was already assigned.
        """
        if assignment.arg in self.yield_mapping:
            raise ValueError(
                f"Multiple assignments to the same argument are forbidden: "
                f"{assignment}"
            )
        self.yield_mapping[assignment.arg] = self.expression(assignment.value)

    def expression(self, expr: ScalarExpression) -> Value:
        """Recursively emits IR for `expr` and returns the resulting value.

        Function expressions dispatch to the method `_<kind>_<fn_name>`, where
        `fn_name` is taken from `fn_attr_mapping` if the expression refers to a
        function attribute.

        Raises:
          ValueError: If a referenced argument is unbound.
          NotImplementedError: If the expression kind is not supported.
        """
        if expr.scalar_arg:
            try:
                return self.block_arg_mapping[expr.scalar_arg.arg]
            except KeyError:
                raise ValueError(
                    f"Argument {expr.scalar_arg.arg} is not bound for "
                    f"this structured op."
                )
        elif expr.scalar_const:
            value_attr = Attribute.parse(expr.scalar_const.value)
            return arith.ConstantOp(value_attr.type, value_attr).result
        elif expr.scalar_index:
            dim_attr = IntegerAttr.get(
                IntegerType.get_signless(64), expr.scalar_index.dim
            )
            return linalg.IndexOp(dim_attr).result
        elif expr.scalar_fn:
            kind = expr.scalar_fn.kind.name.lower()
            fn_name = expr.scalar_fn.fn_name
            if expr.scalar_fn.attr_name:
                fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name]
            fn = self._get_function(f"_{kind}_{fn_name}")
            operand_values = [
                self.expression(operand) for operand in expr.scalar_fn.operands
            ]
            # Type functions additionally receive the target type variable name.
            if expr.scalar_fn.kind == FunctionKind.TYPE:
                operand_values = [expr.scalar_fn.type_var.name] + operand_values
            return fn(*operand_values)
        raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")

    def yield_outputs(self, *output_names: str):
        """Emits a `linalg.yield` of the values assigned to `output_names`.

        Raises:
          ValueError: If any output name was never assigned.
        """
        output_values = []
        for n in output_names:
            try:
                output_values.append(self.yield_mapping[n])
            except KeyError:
                raise ValueError(
                    f"Body assignments do not assign all outputs: missing '{n}'"
                )
        linalg.YieldOp(output_values)

    def _get_function(self, fn_name: str) -> Callable:
        """Returns the bound method named `fn_name`, or raises ValueError."""
        try:
            return getattr(self, fn_name)
        except AttributeError:
            raise ValueError(f"Function '{fn_name}' is not a known function")

    def _cast(
        self, type_var_name: str, operand: Value, is_unsigned_cast: bool = False
    ) -> Value:
        """Casts `operand` to the type bound to `type_var_name`.

        Returns `operand` unchanged if it already has the target type.

        Raises:
          ValueError: If the type variable is unbound or the target type is
            neither an integer nor a supported floating point type.
        """
        try:
            to_type = self.type_mapping[type_var_name]
        except KeyError:
            raise ValueError(
                f"Unbound type variable '{type_var_name}' ("
                f"expected one of {self.type_mapping.keys()})"
            )
        if operand.type == to_type:
            return operand
        if _is_integer_type(to_type):
            return self._cast_to_integer(to_type, operand, is_unsigned_cast)
        if _is_floating_point_type(to_type):
            return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
        # Previously this fell through and silently returned None.
        raise ValueError(
            f"Unable to cast body expression from {operand.type} to {to_type}"
        )

    def _cast_to_integer(
        self, to_type: Type, operand: Value, is_unsigned_cast: bool
    ) -> Value:
        """Casts a float, index or integer `operand` to integer type `to_type`."""
        to_width = IntegerType(to_type).width
        operand_type = operand.type
        if _is_floating_point_type(operand_type):
            if is_unsigned_cast:
                return arith.FPToUIOp(to_type, operand).result
            return arith.FPToSIOp(to_type, operand).result
        if _is_index_type(operand_type):
            return arith.IndexCastOp(to_type, operand).result
        # Assume integer.
        from_width = IntegerType(operand_type).width
        if to_width > from_width:
            if is_unsigned_cast:
                return arith.ExtUIOp(to_type, operand).result
            return arith.ExtSIOp(to_type, operand).result
        elif to_width < from_width:
            return arith.TruncIOp(to_type, operand).result
        raise ValueError(
            f"Unable to cast body expression from {operand_type} to {to_type}"
        )

    def _cast_to_floating_point(
        self, to_type: Type, operand: Value, is_unsigned_cast: bool
    ) -> Value:
        """Casts an integer or float `operand` to floating point type `to_type`."""
        operand_type = operand.type
        if _is_integer_type(operand_type):
            if is_unsigned_cast:
                return arith.UIToFPOp(to_type, operand).result
            return arith.SIToFPOp(to_type, operand).result
        # Assume FloatType.
        to_width = _get_floating_point_width(to_type)
        from_width = _get_floating_point_width(operand_type)
        if to_width > from_width:
            return arith.ExtFOp(to_type, operand).result
        elif to_width < from_width:
            return arith.TruncFOp(to_type, operand).result
        raise ValueError(
            f"Unable to cast body expression from {operand_type} to {to_type}"
        )

    def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value:
        return self._cast(type_var_name, operand, False)

    def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
        return self._cast(type_var_name, operand, True)

    def _unary_exp(self, x: Value) -> Value:
        if _is_floating_point_type(x.type):
            return math.ExpOp(x).result
        raise NotImplementedError(f"Unsupported 'exp' operand: {x}")

    def _unary_log(self, x: Value) -> Value:
        if _is_floating_point_type(x.type):
            return math.LogOp(x).result
        raise NotImplementedError(f"Unsupported 'log' operand: {x}")

    def _unary_abs(self, x: Value) -> Value:
        if _is_floating_point_type(x.type):
            return math.AbsFOp(x).result
        raise NotImplementedError(f"Unsupported 'abs' operand: {x}")

    def _unary_ceil(self, x: Value) -> Value:
        if _is_floating_point_type(x.type):
            return math.CeilOp(x).result
        raise NotImplementedError(f"Unsupported 'ceil' operand: {x}")

    def _unary_floor(self, x: Value) -> Value:
        if _is_floating_point_type(x.type):
            return math.FloorOp(x).result
        raise NotImplementedError(f"Unsupported 'floor' operand: {x}")

    def _unary_negf(self, x: Value) -> Value:
        if _is_floating_point_type(x.type):
            return arith.NegFOp(x).result
        if _is_complex_type(x.type):
            return complex.NegOp(x).result
        raise NotImplementedError(f"Unsupported 'negf' operand: {x}")

    def _binary_add(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.AddFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.AddIOp(lhs, rhs).result
        if _is_complex_type(lhs.type):
            return complex.AddOp(lhs, rhs).result
        raise NotImplementedError(f"Unsupported 'add' operands: {lhs}, {rhs}")

    def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.SubFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.SubIOp(lhs, rhs).result
        if _is_complex_type(lhs.type):
            return complex.SubOp(lhs, rhs).result
        raise NotImplementedError(f"Unsupported 'sub' operands: {lhs}, {rhs}")

    def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.MulFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.MulIOp(lhs, rhs).result
        if _is_complex_type(lhs.type):
            return complex.MulOp(lhs, rhs).result
        raise NotImplementedError(f"Unsupported 'mul' operands: {lhs}, {rhs}")

    def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.MaximumFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.MaxSIOp(lhs, rhs).result
        raise NotImplementedError(f"Unsupported 'max' operands: {lhs}, {rhs}")

    def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.MaximumFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.MaxUIOp(lhs, rhs).result
        raise NotImplementedError(
            f"Unsupported 'max_unsigned' operands: {lhs}, {rhs}"
        )

    def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.MinimumFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.MinSIOp(lhs, rhs).result
        raise NotImplementedError(f"Unsupported 'min' operands: {lhs}, {rhs}")

    def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.MinimumFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.MinUIOp(lhs, rhs).result
        raise NotImplementedError(
            f"Unsupported 'min_unsigned' operands: {lhs}, {rhs}"
        )


def _infer_structured_outs(
    op_config: LinalgStructuredOpConfig,
    in_arg_defs: Sequence[OperandDefConfig],
    ins: Sequence[Value],
    out_arg_defs: Sequence[OperandDefConfig],
    outs: Union[Sequence[Value], OpResultList],
) -> Tuple[ValueList, List[Type]]:
    """Infers implicit outs and output types.

    Respects existing contents of outs if not empty.

    Returns:
      normalized outs, output types

    Raises:
      NotImplementedError: If `outs` is empty; inference is not implemented.
    """
    # If outs were explicitly provided, we accept them verbatim.
    if outs:
        return outs, [out.type for out in outs]

    raise NotImplementedError(
        "Output tensor inference not yet supported for structured ops"
    )


def _get_types_from_values(*values: Value) -> Sequence[Type]:
    """Returns the `.type` of each value, in order."""
    return [v.type for v in values]


def _get_operand_def_names(*operand_configs: OperandDefConfig) -> Sequence[str]:
    """Returns the operand definition name of each config, in order."""
    return [odc.operand_def.name for odc in operand_configs]


def _add_type_mapping(
    operand_config: OperandDefConfig,
    operand_type: Type,
    type_mapping: Dict[str, Type],
    block_arg_types: List[Type],
):
    """Binds the operand's type variable and appends its block argument type.

    For operands with a shape map the element type of `operand_type` is used,
    otherwise `operand_type` itself.

    Raises:
      ValueError: If a shaped operand's type is not a ShapedType, or if the
        type variable is already bound to a different type.
    """
    element_or_self_type = operand_type
    # Get the element type for tensor operands and the type itself for scalars.
    if operand_config.shape_map:
        try:
            element_or_self_type = ShapedType(operand_type).element_type
        except Exception as e:
            raise ValueError(f"Expected ShapedType but got {operand_type}") from e
    name = operand_config.type_var.name
    if name in type_mapping:
        if type_mapping[name] != element_or_self_type:
            raise ValueError(
                f"Cannot overwrite type mapping {name} = "
                f"{type_mapping[name]} by type {element_or_self_type}"
            )
    type_mapping[name] = element_or_self_type
    block_arg_types.append(element_or_self_type)


def _is_complex_type(t: Type) -> bool:
    return ComplexType.isinstance(t)


def _is_floating_point_type(t: Type) -> bool:
    # TODO: Create a FloatType in the Python API and implement the switch
    # there.
    return (
        F64Type.isinstance(t)
        or F32Type.isinstance(t)
        or F16Type.isinstance(t)
        or BF16Type.isinstance(t)
    )


def _is_integer_type(t: Type) -> bool:
    return IntegerType.isinstance(t)


def _is_index_type(t: Type) -> bool:
    return IndexType.isinstance(t)


def _get_floating_point_width(t: Type) -> int:
    # TODO: Create a FloatType in the Python API and implement the switch
    # there.
    if F64Type.isinstance(t):
        return 64
    if F32Type.isinstance(t):
        return 32
    if F16Type.isinstance(t):
        return 16
    if BF16Type.isinstance(t):
        return 16
    raise NotImplementedError(f"Unhandled floating point type switch {t}")