xref: /llvm-project/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py (revision f9008e6366c2496b1ca1785b891d5578174ad63e)
1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from typing import Dict, List, Sequence, Union
6
7from contextlib import contextmanager
8import functools
9import inspect
10import threading
11
12from ..... import ir
13from ...._ods_common import (
14    get_op_result_or_value as _get_op_result_or_value,
15    get_op_results_or_values as _get_op_results_or_values,
16)
17from .comprehension import *
18from .config import *
19from .emitter import *
20
# Thread-local storage for the op definition that is currently being built.
# `bind_op_def` sets/clears its `current_op_def` attribute and
# `current_op_def()` reads it.
_CONTEXT = threading.local()

# Accepted forms for the `outs` keyword of a defined op callable: a single
# operation / op view, an op result list, or a sequence mixing values,
# operations and op views.
StructuredOpOuts = Union[
    ir.Operation,
    ir.OpView,
    ir.OpResultList,
    Sequence[Union[ir.Value, ir.Operation, ir.OpView]],
]
29
30
@contextmanager
def bind_op_def(op_def: LinalgOpDef):
    """Context manager that makes `op_def` the current op definition.

    The definition is stored on the thread-local `_CONTEXT` for the duration of
    the `with` body and removed again on exit, including on error.

    Raises:
      ValueError: if an op definition is already bound on this thread.
    """
    slot = "current_op_def"
    if hasattr(_CONTEXT, slot):
        raise ValueError("Cannot recursively define an operation")
    setattr(_CONTEXT, slot, op_def)
    try:
        yield op_def
    finally:
        # Always unbind so that a failing definition does not block later ones.
        delattr(_CONTEXT, slot)
40
41
def current_op_def() -> LinalgOpDef:
    """Returns the op definition currently bound on this thread.

    Raises:
      ValueError: if no op definition is bound, i.e. this is called outside of
        a `bind_op_def` context.
    """
    try:
        return _CONTEXT.current_op_def
    except AttributeError:
        # The missing thread-local attribute is an implementation detail;
        # `from None` keeps it out of the traceback shown to the user.
        raise ValueError(
            "Attempt to access the current op definition being defined "
            "but none is set. Did you mean to call this in an op definition?"
        ) from None
50
51
def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList:
    """Normalizes the accepted forms of `outs` for the emitter.

    A single operation or op view is expanded via `_get_op_results_or_values`,
    an `ir.OpResultList` is returned unchanged, and any other iterable is
    converted element by element with `_get_op_result_or_value`.
    """
    single_op_types = (ir.Operation, ir.OpView)
    if isinstance(outs, single_op_types):
        return _get_op_results_or_values(outs)

    if isinstance(outs, ir.OpResultList):
        # Already a list of results; hand it back untouched.
        return outs

    return list(map(_get_op_result_or_value, outs))
59
60
class DefinedOpCallable:
    """Callable that wraps any defined op function.

    Holds the op name and its `LinalgOpDef`; calling the instance emits the op
    as IR in the current `ir.Context`.
    """

    def __init__(self, op_name: str, op_def: LinalgOpDef):
        self.op_name = op_name
        self.op_def = op_def

    def __call__(
        self,
        *ins: Union[ir.Operation, ir.OpView, ir.Value],
        outs: StructuredOpOuts,
        **kwargs,
    ):
        """Emits the corresponding op definition as IR.

        Most arguments are passed through to the underlying emitter. The following
        keyword argument is interpreted here:
          emit_generic: Emits a generic form as appropriate (default False). If
            False, a named form is emitted (which must have been built in to the
            compiler). The generic form is used regardless of this flag when
            "linalg.<op_name>" is not a registered operation in the current
            context.

        Raises:
          ValueError: if `emit_generic` is not a bool.
          NotImplementedError: if the op definition yields anything other than
            exactly one op config, or a config that is not a structured op.
        """
        emit_generic = kwargs.pop("emit_generic", False)
        if not isinstance(emit_generic, bool):
            raise ValueError(
                f"The named argument 'emit_generic' needs to be "
                f"of type bool but got {type(emit_generic)}"
            )

        ctx = ir.Context.current
        op_configs = LinalgOpConfig.from_linalg_op_def(self.op_def, context=ctx)

        if len(op_configs) != 1:
            # TODO: Support composite ops.
            raise NotImplementedError(
                f"Emission of composite linalg ops not supported: {op_configs}"
            )

        # NOTE(review): the returned descriptor is not used. The call is
        # presumably kept for its effect on the context (making sure the
        # linalg dialect is available before the registration query below) --
        # confirm before removing it.
        ctx.get_dialect_descriptor("linalg")
        fully_qualified_name = "linalg." + self.op_name
        # Fall back to the generic form when no named op has been registered.
        emit_generic = emit_generic or not ctx.is_registered_operation(
            fully_qualified_name
        )

        op_config = op_configs[0]
        out_values = _prepare_structured_op_outs(outs)
        in_values = [_get_op_result_or_value(i) for i in ins]
        if op_config.structured_op:
            if emit_generic:
                return emit_generic_structured_op(
                    op_config.structured_op, *in_values, outs=out_values, **kwargs
                )
            return emit_named_structured_op(
                op_config.structured_op,
                self.op_name,
                self.op_def.metadata.cpp_class_name,
                *in_values,
                outs=out_values,
                **kwargs,
            )

        raise NotImplementedError(
            f"Emission of linalg op type not supported: {op_config}"
        )
127
128
def linalg_structured_op(
    dsl_func=None, *, op_name=None, op_class_name=None
) -> DefinedOpCallable:
    """Decorator that turns a DSL function into a `DefinedOpCallable`.

    Can be applied bare (`@linalg_structured_op`) or with keyword arguments
    (`@linalg_structured_op(op_name=...)`). In the latter form `dsl_func` is
    None and a `functools.partial` of this function is returned, waiting for
    the function to decorate.

    Args:
      dsl_func: The function describing the op. Every parameter must have a
        default value that is one of the supported operand definition types;
        the function is invoked once, with those defaults, while the new op
        definition is bound as the current one.
      op_name: Name of the op. Defaults to `dsl_func.__name__`.
      op_class_name: C++ class name recorded in the op definition. Defaults to
        the title-cased, underscore-stripped `op_name` followed by "Op".

    Raises:
      ValueError: if a parameter default is not a supported definition type.
    """
    if dsl_func is None:
        # Curry the keyword args in for delayed application.
        return functools.partial(
            linalg_structured_op, op_name=op_name, op_class_name=op_class_name
        )
    # Determine default names by introspecting the function.
    if op_name is None:
        op_name = dsl_func.__name__
    if op_class_name is None:
        # Camel case it.
        op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op"

    op_def = LinalgOpDef(
        name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func)
    )

    # Definition types accepted as parameter defaults.
    operand_def_types = (
        TensorDef,
        ScalarDef,
        IndexAttrDef,
        UnaryFnAttrDef,
        BinaryFnAttrDef,
        TypeFnAttrDef,
    )

    # Extract arguments and TensorDefs from the signature.
    dsl_func_args = []
    sig = inspect.signature(dsl_func)
    for param_name, param in sig.parameters.items():
        param_default = param.default
        if not isinstance(param_default, operand_def_types):
            raise ValueError(
                f"@linalg_structured_op function parameters must be defaulted as "
                f"TensorDef(...), ScalarDef(...), IndexAttrDef(...), "
                f"UnaryFnAttrDef(...), BinaryFnAttrDef(...), or "
                f"TypeFnAttrDef(...): "
                f"Found {param_name}: {param_default}"
            )
        op_def.add_operand(param_name, param_default.operand_def)
        dsl_func_args.append(param_default)

    # Invoke the DSL func to finish populating the op definition.
    with bind_op_def(op_def):
        dsl_func(*dsl_func_args)

    # TODO: The returned callable should be an IR emitter but that is not
    # upstreamed yet.
    return DefinedOpCallable(op_name, op_def)
180
181
def domain(*dimensions: DimDef):
    """Appends `dimensions` to the domain of the op currently being defined.

    Raises:
      ValueError: if any argument is not a `DimDef`, or if no op definition
        is currently bound.
    """
    for dim in dimensions:
        if not isinstance(dim, DimDef):
            raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
    op_def = current_op_def()
    op_def.domain.extend(dimensions)
186
187
def implements(*interfaces: OpInterfaceDef):
    """Records `interfaces` in the metadata of the op currently being defined.

    Raises:
      ValueError: if any argument is not an `OpInterfaceDef`, or if no op
        definition is currently bound.
    """
    for candidate in interfaces:
        if not isinstance(candidate, OpInterfaceDef):
            raise ValueError(
                f"Expected interfaces of type OpInterfaceDef but got {interfaces}"
            )
    metadata = current_op_def().metadata
    metadata.implements.extend(interfaces)
194
195
def defines(*definitions: OpDefinitionDef):
    """Records `definitions` in the metadata of the op currently being defined.

    Raises:
      ValueError: if any argument is not an `OpDefinitionDef`, or if no op
        definition is currently bound.
    """
    for candidate in definitions:
        if not isinstance(candidate, OpDefinitionDef):
            raise ValueError(
                f"Expected definitions of type OpDefinitionDef but got {definitions}"
            )
    metadata = current_op_def().metadata
    metadata.defines.extend(definitions)
202