xref: /llvm-project/mlir/python/mlir/dialects/transform/extras/__init__.py (revision 5caab8bbc0f89f46aca07be2090c8d23c78605ba)
1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from typing import Callable, Optional, Sequence, Union
6
7from ....extras.meta import region_op
8from .... import ir
9from ... import transform
10from .. import (
11    AnyOpType,
12    AnyParamType,
13    AnyValueType,
14    OperationType,
15    ParamType,
16    NamedSequenceOp,
17    YieldOp,
18    SequenceOp,
19    ApplyPatternsOp,
20)
21from .. import structured
22
23
class Handle(ir.Value):
    """
    Base class for wrappers around different types of transform handle with
    methods to chain further transforms.

    The fields `children` and `parent` are used to capture the relation of
    handles statically in order to enable further analysis. The payload
    operation of a child handle is nested into a region of the payload operation
    of the corresponding parent handle.
    """

    def __init__(
        self,
        v: ir.Value,
        *,
        parent: Optional["Handle"] = None,
        children: Optional[Sequence["Handle"]] = None,
    ):
        """
        Wraps `v` and records its static relation to other handles.

        Args:
          v: The value to wrap.
          parent: The handle this one was derived from, if any.
          children: Initial child handles. Any sequence (e.g. a tuple) is
            accepted; its elements are copied into a new list owned by this
            handle.
        """
        super().__init__(v)
        self.parent = parent
        # Always store a fresh, mutable list: `OpHandle.match_ops` appends to
        # `children`, which would fail for a tuple and would otherwise silently
        # mutate a list object owned by the caller.
        self.children = list(children) if children is not None else []
45
@ir.register_value_caster(AnyOpType.get_static_typeid())
@ir.register_value_caster(OperationType.get_static_typeid())
class OpHandle(Handle):
    """
    Wrapper around a transform operation handle with methods to chain further
    transforms.

    The decorators register this class via `ir.register_value_caster` for the
    `AnyOpType` and `OperationType` type ids.
    """

    def __init__(
        self,
        v: ir.Value,
        *,
        parent: Optional[Handle] = None,
        children: Optional[Sequence[Handle]] = None,
    ):
        super().__init__(v, parent=parent, children=children)

    def get_result(self, indices: Optional[Sequence[int]] = None) -> "ValueHandle":
        """
        Emits a `transform.GetResultOp`.
        Returns a handle to the result of the payload operation at the given
        indices.

        Args:
          indices: Positions of the results to select. Defaults to `[0]` when
            omitted or `None`.
        """
        # Use a `None` sentinel instead of a list default: a default list is
        # created once and shared by every call of the method.
        if indices is None:
            indices = [0]
        get_result_op = transform.GetResultOp(
            AnyValueType.get(),
            self,
            indices,
        )
        return get_result_op.result

    def match_ops(
        self,
        ops: Union[
            str,
            ir.OpView,
            structured.MatchInterfaceEnum,
            Sequence[Union[str, ir.OpView]],
        ],
    ) -> "OpHandle":
        """
        Emits a `transform.structured.MatchOp`.
        Returns a handle to payload ops that match the given names, types, or
        interface. If only a single type is given, the value wrapped by the
        resulting handle is populated with the respective type.

        Args:
          ops: One of
            - a `structured.MatchInterfaceEnum` member, or a string naming one
              of its members, to match by interface;
            - a single op name string, e.g. `"scf.for"`;
            - a single op class exposing `OPERATION_NAME`, e.g. `scf.ForOp`;
            - a sequence of op name strings and/or op classes (the resulting
              handle then has the generic `AnyOpType`).

        Returns:
          A new `OpHandle` whose `parent` is this handle; it is also appended
          to this handle's `children`.
        """
        # Handle interface. This check runs first, so a string that names a
        # `MatchInterfaceEnum` member is treated as an interface, not an op name.
        if isinstance(ops, structured.MatchInterfaceEnum) or (
            isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
        ):
            if isinstance(ops, str):
                ops = structured.MatchInterfaceEnum[ops]
            match_op = structured.MatchOp(
                AnyOpType.get(),
                self,
                interface=ops,
            )

        # Handle op name(s), either given directly as string or given as op.
        else:
            if isinstance(ops, str):
                op_type = OperationType.get(ops)
                op_names = [ops]
            elif isinstance(ops, Sequence):
                # Several names: the result can only be typed generically.
                op_type = AnyOpType.get()
                op_names = [
                    op if isinstance(op, str) else op.OPERATION_NAME for op in ops
                ]
            else:
                op_type = OperationType.get(ops.OPERATION_NAME)
                op_names = [ops.OPERATION_NAME]
            match_op = structured.MatchOp.match_op_names(
                op_type,
                self,
                op_names,
            )

        handle = OpHandle(match_op.results_, parent=self)
        self.children.append(handle)
        return handle

    def print(self, name: Optional[str] = None) -> "OpHandle":
        """
        Emits a `transform.PrintOp` to print this handle and an optional message.
        Returns the existing handle to facilitate further chaining.
        """
        transform.PrintOp(target=self, name=name)
        return self
133
134
@ir.register_value_caster(AnyParamType.get_static_typeid())
@ir.register_value_caster(ParamType.get_static_typeid())
class ParamHandle(Handle):
    """
    Wrapper around a transform parameter handle.

    The decorators register this class via `ir.register_value_caster` for the
    `AnyParamType` and `ParamType` type ids. No constructor override is
    needed: `Handle.__init__` already takes the wrapped value plus the
    keyword-only `parent` and `children` arguments.
    """
148
149
@ir.register_value_caster(AnyValueType.get_static_typeid())
class ValueHandle(Handle):
    """
    Wrapper around a transform value handle, offering methods that emit
    further transform ops on it.

    Construction is inherited unchanged from `Handle`: pass the wrapped value
    plus the optional keyword-only `parent` and `children`.
    """

    def get_defining_op(self) -> OpHandle:
        """
        Emits a `transform.GetDefiningOp` on this handle and returns a handle
        (typed `AnyOpType`) to the op that defines the wrapped value.
        """
        return transform.GetDefiningOp(AnyOpType.get(), self).result
176
177
def constant_param(value: Union[ir.Attribute, int]) -> ParamHandle:
    """
    Emits a `transform.ParamConstantOp` holding `value` and returns a handle to
    the newly created parameter.

    A plain Python `int` is first wrapped in a signless 64-bit `IntegerAttr`.
    If the attribute's type is an `IntegerType`, the parameter type is
    `transform.param` parametrized with that integer type; otherwise it is
    `transform.any_param`.
    """
    attr = (
        ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value)
        if isinstance(value, int)
        else value
    )
    if not isinstance(attr.type, ir.IntegerType):
        param_type = AnyParamType.get()
    else:
        param_type = ParamType.get(attr.type)
    return transform.ParamConstantOp(param_type, attr).param
193
194
def insert_transform_script(
    block_or_insertion_point: Union[ir.Block, ir.InsertionPoint],
    script: Callable[[OpHandle], None],
    dump_script: bool = False,
    *,
    sequence_name: str = "__transform_main",
) -> None:
    """
    Inserts the transform script of the schedule into the module, wrapped in a
    newly created `transform.named_sequence` op. The script should accept an
    instance of OpHandle as argument, which will be called with the block arg
    of the newly created named_sequence op. Ops emitted by the script are
    placed into the sequence body, which is then terminated with a
    `transform.yield`.

    Args:
      block_or_insertion_point: Where to create the named sequence op. A block
        is treated as an insertion point at its beginning.
      script: Callable that emits the body of the sequence.
      dump_script: If True, print the created named sequence op.
      sequence_name: Symbol name of the created named sequence. Defaults to
        `__transform_main`.

    Example:
    This python code
    ```
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        def match_for_loops(root: OpHandle):
            root.match_ops(scf.ForOp)
        insert_transform_script(module.body, match_for_loops)
    ```
    generates the following IR:
    ```
    module {
        transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
            %0 = transform.structured.match ops{["scf.for"]} in %arg0
                 : (!transform.any_op) -> !transform.op<"scf.for">
            transform.yield
        }
    }
    ```
    """
    # Derive the MLIR context from the block's owning op so the caller does
    # not need one active.
    if isinstance(block_or_insertion_point, ir.Block):
        context = block_or_insertion_point.owner.context
        insertion_point = ir.InsertionPoint.at_block_begin(block_or_insertion_point)
    else:
        context = block_or_insertion_point.block.owner.context
        insertion_point = block_or_insertion_point

    with context, ir.Location.unknown(context):
        with insertion_point:
            named_sequence_op = NamedSequenceOp(
                sequence_name, [AnyOpType.get()], []
            )
        with ir.InsertionPoint(named_sequence_op.body):
            script(named_sequence_op.bodyTarget)
            YieldOp([])

    if dump_script:
        print(named_sequence_op)
242
243
# Module-level helpers created with `region_op` (imported from
# `....extras.meta`) for the `SequenceOp`, `NamedSequenceOp` and
# `ApplyPatternsOp` transform ops.
# NOTE(review): presumably these let callers build the op and populate its
# region in one step, with `terminator` appended to the body; confirm against
# `region_op`'s implementation.
# NOTE(review): `SequenceOp.__base__` presumably selects the generated parent
# class of a Python-side `SequenceOp` extension (e.g. for its builder
# signature); confirm against the `transform` dialect module.
sequence = region_op(SequenceOp.__base__, terminator=YieldOp)
named_sequence = region_op(NamedSequenceOp, terminator=YieldOp)
# Unlike the two builders above, no `terminator` is passed here.
apply_patterns = region_op(ApplyPatternsOp)
247