# xref: /llvm-project/mlir/python/mlir/dialects/func.py (revision d898ff650ae09e3ef942592aee2e87627f45d7c6)
#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from ._func_ops_gen import *
6from ._func_ops_gen import _Dialect
7
8try:
9    from ..ir import *
10    from ._ods_common import (
11        get_default_loc_context as _get_default_loc_context,
12        _cext as _ods_cext,
13    )
14
15    import inspect
16
17    from typing import Any, List, Optional, Sequence, Union
18except ImportError as e:
19    raise RuntimeError("Error loading imports from extension module") from e
20
# Keys in a FuncOp's attribute dictionary under which the per-argument and
# per-result attribute arrays are stored (read/written by `FuncOp.arg_attrs`
# and `FuncOp.result_attrs`).
ARGUMENT_ATTRIBUTE_NAME, RESULT_ATTRIBUTE_NAME = "arg_attrs", "res_attrs"
23
24
@_ods_cext.register_operation(_Dialect, replace=True)
class ConstantOp(ConstantOp):
    """Python-side extension of the generated `ConstantOp` class.

    Registered with `replace=True`, so it supersedes the generated class for
    this dialect. Adds a `type` accessor for the produced value.
    """

    @property
    def type(self):
        """Return the type of the op's first result."""
        first_result = self.results[0]
        return first_result.type
32
33
@_ods_cext.register_operation(_Dialect, replace=True)
class FuncOp(FuncOp):
    """Specialization for the func op class.

    Adds a convenience constructor, accessors for the body, signature and
    attributes, and the `from_py_func` decorator that emits a function from a
    Python callable.
    """

    def __init__(
        self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
    ):
        """
        Create a FuncOp with the provided `name`, `type`, and `visibility`.
        - `name` is converted with `str()` and stored as the `sym_name` StringAttr.
        - `type` is either a FunctionType or a tuple `(inputs, results)` of two
          lists of types, from which a FunctionType is built.
        - `visibility` is a string such as `public`, `private`, or `nested`. When
          None, no `sym_visibility` attribute is attached; MLIR treats a symbol
          without an explicit visibility as public.
        - `body_builder` is an optional callback, when provided a new entry block
          is created and the callback is invoked with the new op as argument within
          an InsertionPoint context already set for the block. The callback is
          expected to insert a terminator in the block.
        - `loc` and `ip` are forwarded to the generated builder.
        """
        sym_name = StringAttr.get(str(name))

        # If the type is passed as a tuple, build a FunctionType on the fly.
        if isinstance(type, tuple):
            type = FunctionType.get(inputs=type[0], results=type[1])

        type = TypeAttr.get(type)
        sym_visibility = (
            StringAttr.get(str(visibility)) if visibility is not None else None
        )
        super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
        if body_builder:
            entry_block = self.add_entry_block()
            with InsertionPoint(entry_block):
                body_builder(self)

    @property
    def is_external(self):
        """True when the body region has no blocks (a declaration only)."""
        return len(self.regions[0].blocks) == 0

    @property
    def body(self):
        """The function body region (region 0)."""
        return self.regions[0]

    @property
    def type(self):
        """The function signature, decoded from the `function_type` attribute."""
        return FunctionType(TypeAttr(self.attributes["function_type"]).value)

    @property
    def visibility(self):
        """The raw `sym_visibility` attribute.

        NOTE(review): the attribute is optional (`__init__` omits it when
        `visibility` is None); the lookup presumably raises KeyError then.
        """
        return self.attributes["sym_visibility"]

    @property
    def name(self) -> StringAttr:
        """The symbol name, as a StringAttr."""
        return StringAttr(self.attributes["sym_name"])

    @property
    def entry_block(self):
        """The first block of the body.

        Raises:
          IndexError: if the function is external (has no body).
        """
        if self.is_external:
            raise IndexError("External function does not have a body")
        return self.regions[0].blocks[0]

    def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
        """
        Add an entry block to the function body using the function signature to
        infer block arguments. `arg_locs` is forwarded to `blocks.append` as the
        locations of the new block arguments.
        Returns the newly created block.

        Raises:
          IndexError: if the function already has an entry block.
        """
        if not self.is_external:
            raise IndexError("The function already has an entry block!")
        self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
        return self.body.blocks[0]

    @property
    def arg_attrs(self):
        """Per-argument attributes as an ArrayAttr.

        When the `arg_attrs` attribute is absent, returns (without storing) an
        ArrayAttr holding one empty DictAttr per function input.
        """
        if ARGUMENT_ATTRIBUTE_NAME not in self.attributes:
            return ArrayAttr.get([DictAttr.get({}) for _ in self.type.inputs])
        return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])

    @arg_attrs.setter
    def arg_attrs(self, attribute: Union[ArrayAttr, list]):
        # Accept an ArrayAttr as-is, or wrap a plain list in the op's context.
        if isinstance(attribute, ArrayAttr):
            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
        else:
            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
                attribute, context=self.context
            )

    @property
    def arguments(self):
        """The entry block's arguments (raises IndexError if external)."""
        return self.entry_block.arguments

    @property
    def result_attrs(self):
        """The raw `res_attrs` attribute.

        Unlike `arg_attrs`, no default is synthesized when it is absent.
        """
        return self.attributes[RESULT_ATTRIBUTE_NAME]

    @result_attrs.setter
    def result_attrs(self, attribute: ArrayAttr):
        self.attributes[RESULT_ATTRIBUTE_NAME] = attribute

    @classmethod
    def from_py_func(
        cls,
        *inputs: Type,
        results: Optional[Sequence[Type]] = None,
        name: Optional[str] = None,
    ):
        """Decorator to define an MLIR FuncOp specified as a python function.

        Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
        active for the current thread (i.e. established in a `with` block).

        When applied as a decorator to a Python function, an entry block will
        be constructed for the FuncOp with types as specified in `*inputs`. The
        block arguments will be passed positionally to the Python function. In
        addition, if the Python function accepts keyword arguments generally or
        has a corresponding keyword argument, the following will be passed:
          * `func_op`: The `func` op being defined.

        By default, the function name will be the Python function `__name__`. This
        can be overridden by passing the `name` argument to the decorator.

        If `results` is not specified, then the decorator will implicitly
        insert a `ReturnOp` with the `Value`'s returned from the decorated
        function. It will also set the `FuncOp` type with the actual return
        value types. If `results` is specified, then the decorated function
        must return `None` and no implicit `ReturnOp` is added (nor are the result
        types updated). The implicit behavior is intended for simple, single-block
        cases, and users should specify result types explicitly for any complicated
        cases.

        The decorated function can further be called from Python and will insert
        a `CallOp` at the then-current insertion point, returning either None (
        if no return values), a unary Value (for one result), or a list of Values).
        This mechanism cannot be used to emit recursive calls (by construction).
        """

        def decorator(f):
            # Bind this dialect module under a local name: the module-level
            # `func` is rebound to the decorator alias at the end of the class.
            from . import func

            # Introspect the callable for optional features: `func_op` is passed
            # when `f` takes **kwargs or a keyword-passable `func_op` parameter.
            sig = inspect.signature(f)
            has_arg_func_op = False
            for param in sig.parameters.values():
                if param.kind == param.VAR_KEYWORD:
                    has_arg_func_op = True
                if param.name == "func_op" and (
                    param.kind == param.POSITIONAL_OR_KEYWORD
                    or param.kind == param.KEYWORD_ONLY
                ):
                    has_arg_func_op = True

            # Emit the FuncOp. With an implicit return the result list starts
            # empty and the type is rewritten once the body has been built.
            implicit_return = results is None
            symbol_name = name or f.__name__
            function_type = FunctionType.get(
                inputs=inputs, results=[] if implicit_return else results
            )
            func_op = cls(name=symbol_name, type=function_type)
            with InsertionPoint(func_op.add_entry_block()):
                func_args = func_op.entry_block.arguments
                func_kwargs = {}
                if has_arg_func_op:
                    func_kwargs["func_op"] = func_op
                return_values = f(*func_args, **func_kwargs)
                if not implicit_return:
                    return_types = list(results)
                    assert return_values is None, (
                        "Capturing a python function with explicit `results=` "
                        "requires that the wrapped function returns None."
                    )
                else:
                    # Coerce return values, add ReturnOp and rewrite func type.
                    if return_values is None:
                        return_values = []
                    elif isinstance(return_values, tuple):
                        return_values = list(return_values)
                    elif isinstance(return_values, Value):
                        # Returning a single value is fine, coerce it into a list.
                        return_values = [return_values]
                    elif isinstance(return_values, OpView):
                        # Returning a single operation is fine, use its results.
                        return_values = return_values.operation.results
                    elif isinstance(return_values, Operation):
                        # Returning a single operation is fine, use its results.
                        return_values = return_values.results
                    else:
                        return_values = list(return_values)
                    func.ReturnOp(return_values)
                    # Recompute the function type.
                    return_types = [v.type for v in return_values]
                    function_type = FunctionType.get(
                        inputs=inputs, results=return_types
                    )
                    func_op.attributes["function_type"] = TypeAttr.get(function_type)

            def emit_call_op(*call_args):
                call_op = func.CallOp(
                    return_types, FlatSymbolRefAttr.get(symbol_name), call_args
                )
                # `return_types` is always a list at this point, so test for
                # emptiness (not None) to honor "None if no return values".
                if not return_types:
                    return None
                if len(return_types) == 1:
                    return call_op.result
                return call_op.results

            wrapped = emit_call_op
            wrapped.__name__ = f.__name__
            wrapped.func_op = func_op
            return wrapped

        return decorator


func = FuncOp.from_py_func
249
250
@_ods_cext.register_operation(_Dialect, replace=True)
class CallOp(CallOp):
    """Specialization for the call op class."""

    def __init__(
        self,
        calleeOrResults: Union[FuncOp, List[Type]],
        argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
        arguments: Optional[List] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Creates a call operation.

        The constructor accepts three different forms:

          1. A function op to be called followed by a list of arguments.
          2. A list of result types, followed by the name of the function to be
             called as string, followed by a list of arguments.
          3. A list of result types, followed by the name of the function to be
             called as symbol reference attribute, followed by a list of arguments.

        For example

            f = func.FuncOp("foo", ...)
            func.CallOp(f, [args])
            func.CallOp([result_types], "foo", [args])

        In all cases, the location and insertion point may be specified as keyword
        arguments if not provided by the surrounding context managers.

        Raises:
          ValueError: if the positional arguments match none of the forms above.
        """

        # TODO: consider supporting constructor "overloads", e.g., through a custom
        # or pybind-provided metaclass.
        if isinstance(calleeOrResults, FuncOp):
            # Form 1: result types and callee symbol come from the FuncOp itself.
            if not isinstance(argumentsOrCallee, list):
                raise ValueError(
                    "when constructing a call to a function, expected "
                    + "the second argument to be a list of call arguments, "
                    + f"got {type(argumentsOrCallee)}"
                )
            if arguments is not None:
                raise ValueError(
                    "unexpected third argument when constructing a call "
                    + "to a function"
                )

            super().__init__(
                calleeOrResults.type.results,
                FlatSymbolRefAttr.get(
                    calleeOrResults.name.value, context=_get_default_loc_context(loc)
                ),
                argumentsOrCallee,
                loc=loc,
                ip=ip,
            )
            return

        # Forms 2 and 3: normalize the callee to a FlatSymbolRefAttr. Anything
        # else (including a list, which belongs to form 1) is rejected instead of
        # silently leaving the op unbuilt.
        if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
            callee = argumentsOrCallee
        elif isinstance(argumentsOrCallee, str):
            callee = FlatSymbolRefAttr.get(
                argumentsOrCallee, context=_get_default_loc_context(loc)
            )
        else:
            raise ValueError(
                "when constructing a call to a function by name, "
                + "expected the second argument to be a string or a "
                + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}"
            )

        super().__init__(calleeOrResults, callee, arguments, loc=loc, ip=ip)
331