xref: /llvm-project/mlir/python/mlir/dialects/transform/structured.py (revision 579ced4f8266b273d15b2801067a828151a222ef)
1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from .._structured_transform_ops_gen import *
6from .._structured_transform_ops_gen import _Dialect
7from .._structured_transform_enum_gen import *
8
9try:
10    from ...ir import *
11    from ...dialects import transform
12    from .._ods_common import (
13        DynamicIndexList,
14        IntOrAttrList,
15        MixedValues,
16        OptionalBoolList,
17        OptionalIntList,
18        _cext as _ods_cext,
19        _dispatch_dynamic_index_list,
20        _dispatch_mixed_values,
21        _get_int_array_array_attr,
22        _get_int_array_attr,
23        _get_value_list,
24        _get_value_or_attribute_value,
25    )
26except ImportError as e:
27    raise RuntimeError("Error loading imports from extension module") from e
28
29from typing import List, Optional, Sequence, Union, overload
30
31
@_ods_cext.register_operation(_Dialect, replace=True)
class BufferizeToAllocationOp(BufferizeToAllocationOp):
    """Specialization for BufferizeToAllocationOp class."""

    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        memory_space: Optional[Union[int, str, Attribute]] = None,
        memcpy_op: Optional[str] = None,
        alloc_op: Optional[str] = None,
        bufferize_destination_only: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        # Normalize the memory space: an int becomes its string form, and a
        # string is parsed into an attribute. Note the two `if`s deliberately
        # cascade so an int goes through both conversions.
        if isinstance(memory_space, int):
            memory_space = str(memory_space)
        if isinstance(memory_space, str):
            memory_space = Attribute.parse(memory_space)

        # The op's result types are fixed, so hard-code them here.
        super().__init__(
            transform.AnyValueType.get(),
            transform.AnyOpType.get(),
            target,
            memory_space=memory_space,
            memcpy_op=memcpy_op,
            alloc_op=alloc_op,
            bufferize_destination_only=bufferize_destination_only,
            loc=loc,
            ip=ip,
        )
67
68
@_ods_cext.register_operation(_Dialect, replace=True)
class DecomposeOp(DecomposeOp):
    """Specialization for DecomposeOp class."""

    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
        # The transformed result is always an any-op handle.
        super().__init__(transform.AnyOpType.get(), target, loc=loc, ip=ip)
76
77
@_ods_cext.register_operation(_Dialect, replace=True)
class FuseIntoContainingOp(FuseIntoContainingOp):
    """Specialization for FuseIntoContainingOp class."""

    @overload
    def __init__(
        self,
        fused_op_type: Type,
        new_containing_op_type: Type,
        producer_op: Union[Operation, OpView, Value],
        containing_op: Union[Operation, OpView, Value],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        producer_op: Union[Operation, OpView, Value],
        containing_op: Union[Operation, OpView, Value],
        *,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value],
        new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value],
        producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
        containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        loc=None,
        ip=None,
    ):
        # Dispatch between the two overloads based on whether the leading
        # arguments are result types or op handles.
        if isinstance(fused_op_type_or_producer_op, Type):
            # Explicit-types overload: both leading arguments must be types.
            if not isinstance(new_containing_op_type_or_containing_op, Type):
                raise TypeError(
                    "If 'fused_op_type_or_producer_op' is a type, then "
                    "'new_containing_op_type_or_containing_op' is expected "
                    "to be one as well."
                )
            fused_op_type = fused_op_type_or_producer_op
            new_containing_op_type = new_containing_op_type_or_containing_op
            producer_op = producer_op_or_none
            containing_op = containing_op_or_none
        else:
            # Handles-only overload: default both result types to any-op.
            fused_op_type = transform.AnyOpType.get()
            new_containing_op_type = transform.AnyOpType.get()
            producer_op = fused_op_type_or_producer_op
            containing_op = new_containing_op_type_or_containing_op

        super().__init__(
            fused_op_type,
            new_containing_op_type,
            producer_op,
            containing_op,
            loc=loc,
            ip=ip,
        )
141
142
@_ods_cext.register_operation(_Dialect, replace=True)
class FuseOp(FuseOp):
    """Specialization for FuseOp class."""

    @overload
    def __init__(
        self,
        loop_types: Union[Type, Sequence[Type]],
        target: Union[Operation, Value, OpView],
        *,
        tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        tile_interchange: OptionalIntList = None,
        apply_cleanup: Optional[bool] = False,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, Value, OpView],
        *,
        tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        tile_interchange: OptionalIntList = None,
        apply_cleanup: Optional[bool] = False,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value],
        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
        *,
        tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        tile_interchange: OptionalIntList = None,
        apply_cleanup: Optional[bool] = False,
        loc=None,
        ip=None,
    ):
        # Treat absent/empty size lists as empty, then keep only the static
        # component of each mixed index list.
        _, tile_sizes, _ = _dispatch_dynamic_index_list(tile_sizes or [])
        _, tile_interchange, _ = _dispatch_dynamic_index_list(tile_interchange or [])
        # One loop is generated per non-zero tile size.
        num_loops = sum(1 for size in tile_sizes if size != 0)

        if isinstance(loop_types_or_target, (Operation, Value, OpView)):
            # No explicit loop types: default each to any-op.
            assert target_or_none is None, "Cannot construct FuseOp with two targets."
            target = loop_types_or_target
            loop_types = [transform.AnyOpType.get()] * num_loops
        else:
            # Explicit loop types: a single type is replicated per loop.
            target = target_or_none
            if isinstance(loop_types_or_target, Type):
                loop_types = [loop_types_or_target] * num_loops
            else:
                loop_types = loop_types_or_target

        super().__init__(
            target.type,
            loop_types,
            target,
            tile_sizes=tile_sizes,
            tile_interchange=tile_interchange,
            apply_cleanup=apply_cleanup,
            loc=loc,
            ip=ip,
        )
212
213
@_ods_cext.register_operation(_Dialect, replace=True)
class GeneralizeOp(GeneralizeOp):
    """Specialization for GeneralizeOp class."""

    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
        # The transformed result is always an any-op handle.
        super().__init__(transform.AnyOpType.get(), target, loc=loc, ip=ip)
221
222
@_ods_cext.register_operation(_Dialect, replace=True)
class InterchangeOp(InterchangeOp):
    """Specialization for InterchangeOp class."""

    def __init__(
        self,
        target: Union[Operation, Value],
        *,
        iterator_interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        # The transformed result is always an any-op handle.
        super().__init__(
            transform.AnyOpType.get(),
            target,
            iterator_interchange=iterator_interchange,
            loc=loc,
            ip=ip,
        )
243
244
@_ods_cext.register_operation(_Dialect, replace=True)
class MapCopyToThreadsOp(MapCopyToThreadsOp):
    """Specialization for MapCopyToThreadsOp class."""

    @overload
    def __init__(
        self,
        forall_op_type: Type,
        tiled_op_type: Type,
        target: Union[Operation, OpView, Value],
        *,
        total_num_threads: Union[int, IntegerAttr],
        desired_bit_alignment: Union[int, IntegerAttr],
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        total_num_threads: Union[int, IntegerAttr],
        desired_bit_alignment: Union[int, IntegerAttr],
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        forall_op_type_or_target: Union[Operation, OpView, Type, Value],
        tiled_op_type_or_none: Optional[Type] = None,
        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        total_num_threads: Union[int, IntegerAttr],
        desired_bit_alignment: Union[int, IntegerAttr],
        loc=None,
        ip=None,
    ):
        """Dispatches between the two overloads.

        Either both result types are given followed by the target handle, or
        only the target handle is given and both result types default to
        any-op.

        Raises:
            TypeError: If the first argument is a type but the second is not,
                i.e. the explicit-types overload is used incompletely.
        """
        if isinstance(forall_op_type_or_target, Type):
            # Explicit-types overload: both result types must be given.
            # Mirrors the equivalent check in FuseIntoContainingOp and
            # TileUsingForallOp; previously a missing tiled_op_type fell
            # through as None and failed obscurely in op construction.
            if not isinstance(tiled_op_type_or_none, Type):
                raise TypeError(
                    "If 'forall_op_type_or_target' is a type, then "
                    "'tiled_op_type_or_none' is expected to be one as well."
                )
            forall_op_type = forall_op_type_or_target
            tiled_op_type = tiled_op_type_or_none
            target = target_or_none
        else:
            # Handle-only overload: default both result types to any-op.
            forall_op_type = transform.AnyOpType.get()
            tiled_op_type = transform.AnyOpType.get()
            target = forall_op_type_or_target

        super().__init__(
            forall_op_type,
            tiled_op_type,
            target,
            total_num_threads=total_num_threads,
            desired_bit_alignment=desired_bit_alignment,
            loc=loc,
            ip=ip,
        )
304
305
@_ods_cext.register_operation(_Dialect, replace=True)
class VectorizeOp(VectorizeOp):
    """Specialization for VectorizeOp class."""

    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        vector_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        *,
        vectorize_nd_extract: Optional[bool] = None,
        scalable_sizes: OptionalBoolList = None,
        static_vector_sizes: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        # Sizes are accepted either pre-split ('scalable_sizes' +
        # 'static_vector_sizes' given explicitly, 'vector_sizes' holding only
        # the dynamic part) or combined (only 'vector_sizes' given, which is
        # then split here). Mixing the two forms is an error.
        explicitly_split = (
            scalable_sizes is not None or static_vector_sizes is not None
        )
        if explicitly_split:
            if scalable_sizes is None or static_vector_sizes is None:
                raise TypeError(
                    "'scalable_sizes' and 'static_vector_sizes' must either both "
                    "be given explicitly or both be given as part of 'vector_sizes'."
                )
            dynamic_vector_sizes = vector_sizes
        elif vector_sizes is None:
            # Nothing specified at all: no dynamic sizes.
            dynamic_vector_sizes = []
        else:
            # Combined form: split into dynamic/static/scalable components.
            (
                dynamic_vector_sizes,
                static_vector_sizes,
                scalable_sizes,
            ) = _dispatch_dynamic_index_list(vector_sizes)

        super().__init__(
            target,
            vector_sizes=dynamic_vector_sizes,
            static_vector_sizes=static_vector_sizes,
            scalable_sizes=scalable_sizes,
            vectorize_nd_extract=vectorize_nd_extract,
            loc=loc,
            ip=ip,
        )
350
351
@_ods_cext.register_operation(_Dialect, replace=True)
class MatchOp(MatchOp):
    """Specialization for MatchOp class."""

    @overload
    @classmethod
    def match_op_names(
        cls,
        target: Union[Operation, Value],
        names: Union[str, Sequence[str]],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    @classmethod
    def match_op_names(
        cls,
        result_type: Type,
        target: Union[Operation, Value],
        names: Union[str, Sequence[str]],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @classmethod
    def match_op_names(
        cls,
        result_type_or_target: Union[Type, Operation, Value],
        target_or_names: Union[Operation, Value, Sequence[str], str],
        names_or_none: Optional[Union[Sequence[str], str]] = None,
        *,
        loc=None,
        ip=None,
    ):
        # Dispatch on whether an explicit result type was supplied.
        if isinstance(result_type_or_target, Type):
            result_type = result_type_or_target
            target = target_or_names
            names = names_or_none
        else:
            result_type = transform.AnyOpType.get()
            target = result_type_or_target
            names = target_or_names

        # Accept a single name as shorthand for a one-element list.
        if isinstance(names, str):
            names = [names]

        return cls(
            result_type,
            target,
            ops=ArrayAttr.get([StringAttr.get(name) for name in names]),
            loc=loc,
            ip=ip,
        )
410
411
@_ods_cext.register_operation(_Dialect, replace=True)
class MultiTileSizesOp(MultiTileSizesOp):
    """Specialization for MultiTileSizesOp class."""

    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, Value],
        *,
        dimension: Union[int, IntegerAttr],
        target_size: Union[int, IntegerAttr],
        # Was Optional[Optional[...]]: the redundant double wrapping is
        # collapsed; the accepted values are unchanged.
        divisor: Optional[Union[int, IntegerAttr]] = None,
        loc=None,
        ip=None,
    ):
        # The op has three results (low_size, high_size, split_point) that all
        # share the same handle type, hence result_type is passed three times.
        super().__init__(
            result_type,
            result_type,
            result_type,
            target,
            dimension=dimension,
            target_size=target_size,
            divisor=divisor,
            loc=loc,
            ip=ip,
        )
438
439
@_ods_cext.register_operation(_Dialect, replace=True)
class PadOp(PadOp):
    """Specialization for PadOp class."""

    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        pad_to_multiple_of: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
        padding_dimensions: OptionalIntList = None,
        nofold_flags: OptionalIntList = None,
        transpose_paddings: Optional[
            Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
        ] = None,
        copy_back_op: Optional[Union[str, StringAttr]] = None,
        loc=None,
        ip=None,
    ):
        # Split the mixed pad multiples into dynamic values and a static
        # attribute; absent means no dynamic values and no static attribute.
        if pad_to_multiple_of is None:
            dynamic_pad_to_multiple_of = []
            static_pad_to_multiple_of = None
        else:
            (
                dynamic_pad_to_multiple_of,
                static_pad_to_multiple_of,
                _,
            ) = _dispatch_dynamic_index_list(pad_to_multiple_of)

        # The op has three any-op results (padded, pad, copy).
        any_op_type = transform.AnyOpType.get()
        super().__init__(
            any_op_type,
            any_op_type,
            any_op_type,
            target,
            pad_to_multiple_of=dynamic_pad_to_multiple_of,
            padding_values=padding_values,
            padding_dimensions=padding_dimensions,
            static_pad_to_multiple_of=static_pad_to_multiple_of,
            nofold_flags=nofold_flags,
            transpose_paddings=_get_int_array_array_attr(transpose_paddings),
            copy_back_op=copy_back_op,
            loc=loc,
            ip=ip,
        )
487
488
@_ods_cext.register_operation(_Dialect, replace=True)
class ScalarizeOp(ScalarizeOp):
    """Specialization for ScalarizeOp class."""

    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
        # The result is always an any-op handle.
        super().__init__(transform.AnyOpType.get(), target, loc=loc, ip=ip)
496
497
@_ods_cext.register_operation(_Dialect, replace=True)
class SplitOp(SplitOp):
    """Specialization for SplitOp class."""

    def __init__(
        self,
        target: Union[Operation, Value],
        dimension: Union[int, Attribute],
        chunk_sizes: Union[int, Operation, Value, Attribute],
        *,
        loc=None,
        ip=None,
    ):
        # An integer chunk size is static; anything else is a dynamic value,
        # in which case the static size is the dynamic-size sentinel.
        dynamic_chunk_sizes = None
        if isinstance(chunk_sizes, int):
            static_chunk_sizes = chunk_sizes
        else:
            static_chunk_sizes = ShapedType.get_dynamic_size()
            dynamic_chunk_sizes = chunk_sizes

        super().__init__(
            target.type,
            target,
            dimension=dimension,
            static_chunk_sizes=static_chunk_sizes,
            dynamic_chunk_sizes=dynamic_chunk_sizes,
            loc=loc,
            ip=ip,
        )
527
528
@_ods_cext.register_operation(_Dialect, replace=True)
class TileUsingForOp(TileUsingForOp):
    """Specialization for TileUsingForOp class."""

    @overload
    def __init__(
        self,
        loop_types: Union[Type, List[Type]],
        target: Union[Operation, Value],
        *,
        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, Value, OpView],
        *,
        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        loop_types_or_target: Union[Type, List[Type], Operation, Value],
        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
        *,
        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        """Dispatches between the two overloads and splits mixed sizes.

        Either loop types are given explicitly (a single type is replicated
        once per generated loop), or they default to any-op handles.
        """
        (
            dynamic_sizes,
            static_sizes,
            scalable_sizes,
        ) = _dispatch_dynamic_index_list(sizes)

        # One loop is generated per non-zero size entry. (Previously written
        # as `sum(v if v == 0 else 1 ...)`, which only worked because `v` is 0
        # in the taken branch; spelled as a plain count for clarity and for
        # consistency with FuseOp.)
        num_loops = sum(1 for v in static_sizes if v != 0)

        if isinstance(loop_types_or_target, (Operation, Value, OpView)):
            loop_types = [transform.AnyOpType.get()] * num_loops
            target = loop_types_or_target
            assert (
                target_or_none is None
            ), "Cannot construct TileUsingForOp with two targets."
        else:
            loop_types = (
                ([loop_types_or_target] * num_loops)
                if isinstance(loop_types_or_target, Type)
                else loop_types_or_target
            )
            target = target_or_none

        super().__init__(
            target.type,
            loop_types,
            target,
            dynamic_sizes=dynamic_sizes,
            static_sizes=static_sizes,
            interchange=interchange,
            scalable_sizes=scalable_sizes,
            loc=loc,
            ip=ip,
        )
601
602
@_ods_cext.register_operation(_Dialect, replace=True)
class TileUsingForallOp(TileUsingForallOp):
    """Specialization for TileUsingForallOp class."""

    @overload
    def __init__(
        self,
        loops_type: Type,
        tiled_op_type: Type,
        target: Union[Operation, Value, OpView],
        *,
        num_threads: Optional[MixedValues] = None,
        tile_sizes: MixedValues = None,
        mapping=None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, Value, OpView],
        *,
        num_threads: Optional[MixedValues] = None,
        tile_sizes: MixedValues = None,
        mapping=None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        # Either the loops result type (first overload) or the target handle
        # (second overload).
        loops_type_or_target: Union[Type, Operation, Value, OpView],
        tiled_op_type_or_none: Optional[Type] = None,
        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
        *,
        num_threads: MixedValues = None,
        tile_sizes: MixedValues = None,
        mapping=None,
        loc=None,
        ip=None,
    ):
        """Dispatches between the overloads and unpacks mixed size values.

        Raises:
            TypeError: If the first argument is a type but the second is not,
                i.e. the explicit-types overload is used incompletely.
        """
        # `Type` arguments in the front are optional: add default values to front.
        if isinstance(loops_type_or_target, Type):
            # First overload: type arguments provided.
            if not isinstance(tiled_op_type_or_none, Type):
                raise TypeError(
                    "If 'loops_type_or_target' is a type, then "
                    "'tiled_op_type_or_none' is expected to be one as well."
                )
            loops_type = loops_type_or_target
            tiled_op_type = tiled_op_type_or_none
            target = target_or_none
        else:
            # Last overload: type arguments missing.
            loops_type = transform.AnyOpType.get()
            tiled_op_type = transform.AnyOpType.get()
            target = loops_type_or_target

        # Unpack mixed num_threads into dynamic values, a packed handle, and
        # a static attribute (at most one of the packed/others is populated).
        (
            dynamic_num_threads,
            packed_num_threads,
            num_threads_attr,
        ) = _dispatch_mixed_values(num_threads)

        # Unpack mixed tile_sizes the same way.
        (
            dynamic_tile_sizes,
            packed_tile_sizes,
            tile_sizes_attr,
        ) = _dispatch_mixed_values(tile_sizes)

        super().__init__(
            loops_type,
            tiled_op_type,
            target=target,
            tile_sizes=dynamic_tile_sizes,
            packed_tile_sizes=packed_tile_sizes,
            static_tile_sizes=tile_sizes_attr,
            num_threads=dynamic_num_threads,
            packed_num_threads=packed_num_threads,
            static_num_threads=num_threads_attr,
            mapping=mapping,
            loc=loc,
            ip=ip,
        )
694
695
@_ods_cext.register_operation(_Dialect, replace=True)
class VectorizeChildrenAndApplyPatternsOp(VectorizeChildrenAndApplyPatternsOp):
    """Specialization for VectorizeChildrenAndApplyPatternsOp class."""

    def __init__(
        self,
        target: Union[Operation, Value],
        *,
        disable_multi_reduction_to_contract_patterns: bool = False,
        disable_transfer_permutation_map_lowering_patterns: bool = False,
        vectorize_nd_extract: bool = False,
        vectorize_padding: bool = False,
        loc=None,
        ip=None,
    ):
        # The transformed result is always an any-op handle.
        super().__init__(
            transform.AnyOpType.get(),
            target,
            disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
            disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
            vectorize_nd_extract=vectorize_nd_extract,
            vectorize_padding=vectorize_padding,
            loc=loc,
            ip=ip,
        )
722