# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._structured_transform_ops_gen import *
from .._structured_transform_ops_gen import _Dialect
from .._structured_transform_enum_gen import *

try:
    from ...ir import *
    from ...dialects import transform
    from .._ods_common import (
        DynamicIndexList,
        IntOrAttrList,
        MixedValues,
        OptionalBoolList,
        OptionalIntList,
        _cext as _ods_cext,
        _dispatch_dynamic_index_list,
        _dispatch_mixed_values,
        _get_int_array_array_attr,
        _get_int_array_attr,
        _get_value_list,
        _get_value_or_attribute_value,
    )
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import List, Optional, Sequence, Union, overload


@_ods_cext.register_operation(_Dialect, replace=True)
class BufferizeToAllocationOp(BufferizeToAllocationOp):
    """Specialization for BufferizeToAllocationOp class.

    The result types are not configurable: the allocated-buffer result is
    always `transform.AnyValueType` and the new-ops result is always
    `transform.AnyOpType`.
    """

    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        memory_space: Optional[Union[int, str, Attribute]] = None,
        memcpy_op: Optional[str] = None,
        alloc_op: Optional[str] = None,
        bufferize_destination_only: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        """Builds the op on `target`.

        Args:
          target: value for the op's `target` operand.
          memory_space: optional memory space. An `int` is first turned into
            its decimal string, a `str` is parsed with `Attribute.parse`, and
            an `Attribute` (or `None`) is forwarded unchanged.
          memcpy_op, alloc_op, bufferize_destination_only: forwarded unchanged
            to the generated builder.
          loc, ip: location and insertion point of the new op.
        """
        # No other types are allowed, so hard-code those here.
        allocated_buffer_type = transform.AnyValueType.get()
        new_ops_type = transform.AnyOpType.get()

        if isinstance(memory_space, int):
            memory_space = str(memory_space)
        if isinstance(memory_space, str):
            memory_space = Attribute.parse(memory_space)

        super().__init__(
            allocated_buffer_type,
            new_ops_type,
            target,
            memory_space=memory_space,
            memcpy_op=memcpy_op,
            alloc_op=alloc_op,
            bufferize_destination_only=bufferize_destination_only,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class DecomposeOp(DecomposeOp):
    """Specialization for DecomposeOp class."""

    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
        """Builds the op on `target`; the result type is `transform.AnyOpType`."""
        transformed_type = transform.AnyOpType.get()
        super().__init__(transformed_type, target, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class FuseIntoContainingOp(FuseIntoContainingOp):
    """Specialization for FuseIntoContainingOp class."""

    @overload
    def __init__(
        self,
        fused_op_type: Type,
        new_containing_op_type: Type,
        producer_op: Union[Operation, OpView, Value],
        containing_op: Union[Operation, OpView, Value],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        producer_op: Union[Operation, OpView, Value],
        containing_op: Union[Operation, OpView, Value],
        *,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value],
        new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value],
        producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
        containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Builds the op from either overload.

        When the first argument is a `Type`, the first two arguments are the
        two result types and the next two are the producer and containing
        handles. Otherwise both result types default to `transform.AnyOpType`
        and the first two arguments are the producer and containing handles.

        Raises:
          TypeError: if the first argument is a `Type` but the second is not.
        """
        if isinstance(fused_op_type_or_producer_op, Type):
            if not isinstance(new_containing_op_type_or_containing_op, Type):
                raise TypeError(
                    "If 'fused_op_type_or_producer_op' is a type, then "
                    "'new_containing_op_type_or_containing_op' is expected "
                    "to be one as well."
                )
            fused_op_type = fused_op_type_or_producer_op
            new_containing_op_type = new_containing_op_type_or_containing_op
            producer_op = producer_op_or_none
            containing_op = containing_op_or_none
        else:
            fused_op_type = transform.AnyOpType.get()
            new_containing_op_type = transform.AnyOpType.get()
            producer_op = fused_op_type_or_producer_op
            containing_op = new_containing_op_type_or_containing_op

        super().__init__(
            fused_op_type,
            new_containing_op_type,
            producer_op,
            containing_op,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class FuseOp(FuseOp):
    """Specialization for FuseOp class."""

    @overload
    def __init__(
        self,
        loop_types: Union[Type, Sequence[Type]],
        target: Union[Operation, Value, OpView],
        *,
        tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        tile_interchange: OptionalIntList = None,
        apply_cleanup: Optional[bool] = False,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, Value, OpView],
        *,
        tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        tile_interchange: OptionalIntList = None,
        apply_cleanup: Optional[bool] = False,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value],
        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
        *,
        tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        tile_interchange: OptionalIntList = None,
        apply_cleanup: Optional[bool] = False,
        loc=None,
        ip=None,
    ):
        """Builds the op from either overload.

        One loop result is created per non-zero static tile size. A single
        `Type` passed as loop types is repeated for every loop, a sequence is
        used as-is, and when no types are given every loop result is
        `transform.AnyOpType`. The first result reuses `target.type`.

        Raises:
          TypeError: if loop types are given but the target is missing.
        """
        tile_sizes = tile_sizes if tile_sizes else []
        tile_interchange = tile_interchange if tile_interchange else []
        # Only the static parts are kept; the dynamic and scalable components
        # returned by the dispatcher are discarded here.
        _, tile_sizes, _ = _dispatch_dynamic_index_list(tile_sizes)
        _, tile_interchange, _ = _dispatch_dynamic_index_list(tile_interchange)
        # Zero tile sizes do not produce a loop result.
        num_loops = sum(0 if v == 0 else 1 for v in tile_sizes)

        if isinstance(loop_types_or_target, (Operation, Value, OpView)):
            loop_types = [transform.AnyOpType.get()] * num_loops
            target = loop_types_or_target
            assert target_or_none is None, "Cannot construct FuseOp with two targets."
        else:
            # Fail with a clear message instead of an AttributeError on
            # `None.type` below.
            if target_or_none is None:
                raise TypeError(
                    "If 'loop_types_or_target' is a type or a sequence of "
                    "types, then 'target_or_none' is required."
                )
            loop_types = (
                ([loop_types_or_target] * num_loops)
                if isinstance(loop_types_or_target, Type)
                else loop_types_or_target
            )
            target = target_or_none
        super().__init__(
            target.type,
            loop_types,
            target,
            tile_sizes=tile_sizes,
            tile_interchange=tile_interchange,
            apply_cleanup=apply_cleanup,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class GeneralizeOp(GeneralizeOp):
    """Specialization for GeneralizeOp class."""

    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
        """Builds the op on `target`; the result type is `transform.AnyOpType`."""
        transformed_type = transform.AnyOpType.get()
        super().__init__(transformed_type, target, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class InterchangeOp(InterchangeOp):
    """Specialization for InterchangeOp class."""

    def __init__(
        self,
        target: Union[Operation, Value],
        *,
        iterator_interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        """Builds the op on `target` with result type `transform.AnyOpType`.

        `iterator_interchange` is forwarded unchanged to the generated builder.
        """
        transformed_type = transform.AnyOpType.get()
        super().__init__(
            transformed_type,
            target,
            iterator_interchange=iterator_interchange,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class MapCopyToThreadsOp(MapCopyToThreadsOp):
    """Specialization for MapCopyToThreadsOp class."""

    @overload
    def __init__(
        self,
        forall_op_type: Type,
        tiled_op_type: Type,
        target: Union[Operation, OpView, Value],
        *,
        total_num_threads: Union[int, IntegerAttr],
        desired_bit_alignment: Union[int, IntegerAttr],
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        total_num_threads: Union[int, IntegerAttr],
        desired_bit_alignment: Union[int, IntegerAttr],
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        forall_op_type_or_target: Union[Operation, OpView, Type, Value],
        tiled_op_type_or_none: Optional[Type] = None,
        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        total_num_threads: Union[int, IntegerAttr],
        desired_bit_alignment: Union[int, IntegerAttr],
        loc=None,
        ip=None,
    ):
        """Builds the op from either overload.

        When the first argument is a `Type`, the first two arguments are the
        two result types and the third is the target. Otherwise both result
        types default to `transform.AnyOpType` and the first argument is the
        target.

        Raises:
          TypeError: if the first argument is a `Type` but the second is not.
        """
        if isinstance(forall_op_type_or_target, Type):
            # Same consistency check as FuseIntoContainingOp and
            # TileUsingForallOp: both result types must be given together.
            if not isinstance(tiled_op_type_or_none, Type):
                raise TypeError(
                    "If 'forall_op_type_or_target' is a type, then "
                    "'tiled_op_type_or_none' is expected to be one as well."
                )
            forall_op_type = forall_op_type_or_target
            tiled_op_type = tiled_op_type_or_none
            target = target_or_none
        else:
            forall_op_type = transform.AnyOpType.get()
            tiled_op_type = transform.AnyOpType.get()
            target = forall_op_type_or_target

        super().__init__(
            forall_op_type,
            tiled_op_type,
            target,
            total_num_threads=total_num_threads,
            desired_bit_alignment=desired_bit_alignment,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class VectorizeOp(VectorizeOp):
    """Specialization for VectorizeOp class."""

    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        vector_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        *,
        vectorize_nd_extract: Optional[bool] = None,
        scalable_sizes: OptionalBoolList = None,
        static_vector_sizes: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        """Builds the op on `target`.

        Vector sizes can be given in one of two ways:
          * as a mixed `vector_sizes` list, which is split into dynamic,
            static and scalable parts by `_dispatch_dynamic_index_list`; or
          * with explicit `static_vector_sizes` and `scalable_sizes`, in which
            case `vector_sizes` holds only the dynamic operands (or `None`).

        Raises:
          TypeError: if exactly one of `scalable_sizes` and
            `static_vector_sizes` is given.
        """
        if (
            scalable_sizes is None
            and static_vector_sizes is None
            and vector_sizes is None
        ):
            dynamic_vector_sizes = []
        elif scalable_sizes is None and static_vector_sizes is None:
            (
                dynamic_vector_sizes,
                static_vector_sizes,
                scalable_sizes,
            ) = _dispatch_dynamic_index_list(vector_sizes)
        elif scalable_sizes is None or static_vector_sizes is None:
            raise TypeError(
                "'scalable_sizes' and 'static_vector_sizes' must either both "
                "be given explicitly or both be given as part of 'vector_sizes'."
            )
        else:
            # Explicit static/scalable sizes: `vector_sizes` carries only the
            # dynamic operands. Normalize a missing list to `[]`, as the first
            # branch does, rather than forwarding `None` as an operand list.
            dynamic_vector_sizes = vector_sizes if vector_sizes is not None else []

        super().__init__(
            target,
            vector_sizes=dynamic_vector_sizes,
            static_vector_sizes=static_vector_sizes,
            scalable_sizes=scalable_sizes,
            vectorize_nd_extract=vectorize_nd_extract,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class MatchOp(MatchOp):
    """Specialization for MatchOp class."""

    @overload
    @classmethod
    def match_op_names(
        cls,
        target: Union[Operation, Value],
        names: Union[str, Sequence[str]],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    @classmethod
    def match_op_names(
        cls,
        result_type: Type,
        target: Union[Operation, Value],
        names: Union[str, Sequence[str]],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @classmethod
    def match_op_names(
        cls,
        result_type_or_target: Union[Type, Operation, Value],
        target_or_names: Union[Operation, Value, Sequence[str], str],
        names_or_none: Optional[Union[Sequence[str], str]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Creates a MatchOp whose `ops` attribute lists the given op names.

        The result type defaults to `transform.AnyOpType` when it is not given
        as the first argument. A single name string is treated as a
        one-element list.
        """
        if isinstance(result_type_or_target, Type):
            result_type = result_type_or_target
            target = target_or_names
            names = names_or_none
        else:
            result_type = transform.AnyOpType.get()
            target = result_type_or_target
            names = target_or_names

        if isinstance(names, str):
            names = [names]

        return cls(
            result_type,
            target,
            ops=ArrayAttr.get([StringAttr.get(name) for name in names]),
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class MultiTileSizesOp(MultiTileSizesOp):
    """Specialization for MultiTileSizesOp class."""

    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, Value],
        *,
        dimension: Union[int, IntegerAttr],
        target_size: Union[int, IntegerAttr],
        divisor: Optional[Union[int, IntegerAttr]] = None,
        loc=None,
        ip=None,
    ):
        """Builds the op on `target`.

        `result_type` is used for all three results of the generated op.
        `dimension`, `target_size` and `divisor` are forwarded unchanged.
        """
        super().__init__(
            result_type,
            result_type,
            result_type,
            target,
            dimension=dimension,
            target_size=target_size,
            divisor=divisor,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class PadOp(PadOp):
    """Specialization for PadOp class."""

    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        pad_to_multiple_of: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
        padding_dimensions: OptionalIntList = None,
        nofold_flags: OptionalIntList = None,
        transpose_paddings: Optional[
            Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
        ] = None,
        copy_back_op: Optional[Union[str, StringAttr]] = None,
        loc=None,
        ip=None,
    ):
        """Builds the op on `target`; all three results are `transform.AnyOpType`.

        `pad_to_multiple_of` may mix static and dynamic entries; it is split
        into the dynamic operand list and the static attribute. When it is
        `None`, no dynamic operands are passed and the static attribute is
        left unset. `transpose_paddings` is converted with
        `_get_int_array_array_attr`; the other options are forwarded unchanged.
        """
        if pad_to_multiple_of is None:
            dynamic_pad_to_multiple_of = []
            static_pad_to_multiple_of = None
        else:
            (
                dynamic_pad_to_multiple_of,
                static_pad_to_multiple_of,
                _,
            ) = _dispatch_dynamic_index_list(pad_to_multiple_of)

        transpose_paddings = _get_int_array_array_attr(transpose_paddings)

        any_op_type = transform.AnyOpType.get()
        super().__init__(
            any_op_type,
            any_op_type,
            any_op_type,
            target,
            pad_to_multiple_of=dynamic_pad_to_multiple_of,
            padding_values=padding_values,
            padding_dimensions=padding_dimensions,
            static_pad_to_multiple_of=static_pad_to_multiple_of,
            nofold_flags=nofold_flags,
            transpose_paddings=transpose_paddings,
            copy_back_op=copy_back_op,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class ScalarizeOp(ScalarizeOp):
    """Specialization for ScalarizeOp class."""

    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
        """Builds the op on `target`; the result type is `transform.AnyOpType`."""
        result_type = transform.AnyOpType.get()
        super().__init__(result_type, target, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class SplitOp(SplitOp):
    """Specialization for SplitOp class."""

    def __init__(
        self,
        target: Union[Operation, Value],
        dimension: Union[int, Attribute],
        chunk_sizes: Union[int, Operation, Value, Attribute],
        *,
        loc=None,
        ip=None,
    ):
        """Builds the op on `target`; the result type reuses `target.type`.

        Args:
          target: value for the op's `target` operand.
          dimension: forwarded unchanged to the generated builder.
          chunk_sizes: a static size (Python `int` or an `Attribute` such as an
            `IntegerAttr`), or an `Operation`/`Value` providing a dynamic size.
          loc, ip: location and insertion point of the new op.
        """
        if isinstance(chunk_sizes, (int, Attribute)):
            # An attribute can never be an SSA operand, so it is static too.
            static_chunk_sizes = chunk_sizes
            dynamic_chunk_sizes = None
        else:
            # Dynamic size: mark the static slot as dynamic and pass the
            # handle as the operand.
            static_chunk_sizes = ShapedType.get_dynamic_size()
            dynamic_chunk_sizes = chunk_sizes

        super().__init__(
            target.type,
            target,
            dimension=dimension,
            static_chunk_sizes=static_chunk_sizes,
            dynamic_chunk_sizes=dynamic_chunk_sizes,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class TileUsingForOp(TileUsingForOp):
    """Specialization for TileUsingForOp class."""

    @overload
    def __init__(
        self,
        loop_types: Union[Type, List[Type]],
        target: Union[Operation, Value],
        *,
        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, Value, OpView],
        *,
        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        loop_types_or_target: Union[Type, List[Type], Operation, Value],
        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
        *,
        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        """Builds the op from either overload.

        `sizes` may mix static, dynamic and scalable entries and is split by
        `_dispatch_dynamic_index_list`; `None` means no sizes. One loop result
        is created per non-zero static size. A single `Type` passed as loop
        types is repeated for every loop, a list is used as-is, and when no
        types are given every loop result is `transform.AnyOpType`. The first
        result reuses `target.type`.

        Raises:
          TypeError: if loop types are given but the target is missing.
        """
        # Guard the `None` default, as FuseOp and PadOp do, before dispatching.
        (
            dynamic_sizes,
            static_sizes,
            scalable_sizes,
        ) = _dispatch_dynamic_index_list(sizes if sizes is not None else [])

        # Zero sizes do not produce a loop result.
        num_loops = sum(0 if v == 0 else 1 for v in static_sizes)

        if isinstance(loop_types_or_target, (Operation, Value, OpView)):
            loop_types = [transform.AnyOpType.get()] * num_loops
            target = loop_types_or_target
            assert (
                target_or_none is None
            ), "Cannot construct TileUsingForOp with two targets."
        else:
            # Fail with a clear message instead of an AttributeError on
            # `None.type` below.
            if target_or_none is None:
                raise TypeError(
                    "If 'loop_types_or_target' is a type or a list of types, "
                    "then 'target_or_none' is required."
                )
            loop_types = (
                ([loop_types_or_target] * num_loops)
                if isinstance(loop_types_or_target, Type)
                else loop_types_or_target
            )
            target = target_or_none

        super().__init__(
            target.type,
            loop_types,
            target,
            dynamic_sizes=dynamic_sizes,
            static_sizes=static_sizes,
            interchange=interchange,
            scalable_sizes=scalable_sizes,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class TileUsingForallOp(TileUsingForallOp):
    """Specialization for TileUsingForallOp class."""

    @overload
    def __init__(
        self,
        loops_type: Type,
        tiled_op_type: Type,
        target: Union[Operation, Value, OpView],
        *,
        num_threads: Optional[MixedValues] = None,
        tile_sizes: MixedValues = None,
        mapping=None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, Value, OpView],
        *,
        num_threads: Optional[MixedValues] = None,
        tile_sizes: MixedValues = None,
        mapping=None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        loops_type_or_target: Union[
            Type, Union[Operation, Value, OpView]  # loops_type
        ],  # target
        tiled_op_type_or_none: Optional[Type] = None,
        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
        *,
        num_threads: MixedValues = None,
        tile_sizes: MixedValues = None,
        mapping=None,
        loc=None,
        ip=None,
    ):
        """Builds the op from either overload.

        When the first argument is a `Type`, the first two arguments are the
        two result types and the third is the target. Otherwise both result
        types default to `transform.AnyOpType` and the first argument is the
        target. `num_threads` and `tile_sizes` are each split by
        `_dispatch_mixed_values` into dynamic, packed and static parts.

        Raises:
          TypeError: if the first argument is a `Type` but the second is not.
        """
        # `Type` arguments in the front are optional: add default values to front.
        if isinstance(loops_type_or_target, Type):
            # First overload: type arguments provided.
            if not isinstance(tiled_op_type_or_none, Type):
                raise TypeError(
                    "If 'loops_type_or_target' is a type, then "
                    "'tiled_op_type_or_none' is expected to be one as well."
                )
            loops_type = loops_type_or_target
            tiled_op_type = tiled_op_type_or_none
            target = target_or_none
        else:
            # Last overload: type arguments missing.
            loops_type = transform.AnyOpType.get()
            tiled_op_type = transform.AnyOpType.get()
            target = loops_type_or_target

        # Unpack mixed num_threads.
        (
            dynamic_num_threads,
            packed_num_threads,
            num_threads_attr,
        ) = _dispatch_mixed_values(num_threads)

        # Unpack mixed tile_sizes.
        (
            dynamic_tile_sizes,
            packed_tile_sizes,
            tile_sizes_attr,
        ) = _dispatch_mixed_values(tile_sizes)

        super().__init__(
            loops_type,
            tiled_op_type,
            target=target,
            tile_sizes=dynamic_tile_sizes,
            packed_tile_sizes=packed_tile_sizes,
            static_tile_sizes=tile_sizes_attr,
            num_threads=dynamic_num_threads,
            packed_num_threads=packed_num_threads,
            static_num_threads=num_threads_attr,
            mapping=mapping,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class VectorizeChildrenAndApplyPatternsOp(VectorizeChildrenAndApplyPatternsOp):
    """Specialization for VectorizeChildrenAndApplyPatternsOp class."""

    def __init__(
        self,
        target: Union[Operation, Value],
        *,
        disable_multi_reduction_to_contract_patterns: bool = False,
        disable_transfer_permutation_map_lowering_patterns: bool = False,
        vectorize_nd_extract: bool = False,
        vectorize_padding: bool = False,
        loc=None,
        ip=None,
    ):
        """Builds the op on `target` with result type `transform.AnyOpType`.

        All boolean flags default to `False` and are forwarded unchanged to
        the generated builder.
        """
        transformed_type = transform.AnyOpType.get()
        super().__init__(
            transformed_type,
            target,
            disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
            disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
            vectorize_nd_extract=vectorize_nd_extract,
            vectorize_padding=vectorize_padding,
            loc=loc,
            ip=ip,
        )