# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Dict, List, Sequence, Union

from contextlib import contextmanager
import functools
import inspect
import threading

from ..... import ir
from ...._ods_common import (
    get_op_result_or_value as _get_op_result_or_value,
    get_op_results_or_values as _get_op_results_or_values,
)
from .comprehension import *
from .config import *
from .emitter import *

# Thread-local slot holding the LinalgOpDef currently being populated while a
# `@linalg_structured_op` function body runs (see bind_op_def/current_op_def).
_CONTEXT = threading.local()

# Accepted forms for the `outs=` argument of a DefinedOpCallable.
StructuredOpOuts = Union[
    ir.Operation,
    ir.OpView,
    ir.OpResultList,
    Sequence[Union[ir.Value, ir.Operation, ir.OpView]],
]


@contextmanager
def bind_op_def(op_def: LinalgOpDef):
    """Makes `op_def` the current op definition for the duration of the block.

    Raises:
      ValueError: If an op definition is already bound on this thread (nested
        definitions are not supported).
    """
    if hasattr(_CONTEXT, "current_op_def"):
        raise ValueError("Cannot recursively define an operation")
    _CONTEXT.current_op_def = op_def
    try:
        yield op_def
    finally:
        # Always unbind, even if the DSL function raised, so later
        # definitions on this thread are not rejected as recursive.
        del _CONTEXT.current_op_def


def current_op_def() -> LinalgOpDef:
    """Returns the op definition bound by the enclosing `bind_op_def`.

    Raises:
      ValueError: If no op definition is currently bound on this thread.
    """
    try:
        return _CONTEXT.current_op_def
    except AttributeError:
        # Suppress the internal AttributeError; it is an implementation detail.
        raise ValueError(
            "Attempt to access the current op definition being defined "
            "but none is set. Did you mean to call this in an op definition?"
        ) from None


def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList:
    """Normalizes the `outs=` argument into a list of values.

    An Operation/OpView contributes its results, an OpResultList is returned
    unchanged, and any other sequence is converted element by element.
    """
    if isinstance(outs, (ir.Operation, ir.OpView)):
        return _get_op_results_or_values(outs)
    if isinstance(outs, ir.OpResultList):
        return outs
    return [_get_op_result_or_value(o) for o in outs]


class DefinedOpCallable:
    """Callable that wraps any defined op function."""

    def __init__(self, op_name: str, op_def: LinalgOpDef):
        self.op_name = op_name
        self.op_def = op_def

    def __call__(
        self,
        *ins: Union[ir.Operation, ir.OpView, ir.Value],
        outs: StructuredOpOuts,
        **kwargs,
    ):
        """Emits the corresponding op definition as IR.

        Most arguments are passed through to the underlying emitter. The following
        keyword argument is interpreted here:
          emit_generic: If True, always emits the generic form. Defaults to
            False, in which case the named form is emitted when
            `linalg.<op_name>` is a registered operation in the current
            context (i.e. it was built in to the compiler); otherwise emission
            falls back to the generic form.

        Raises:
          ValueError: If `emit_generic` is not a bool.
          NotImplementedError: If the op definition does not produce exactly one
            config, or that config is not a structured op.
        """
        emit_generic = kwargs.pop("emit_generic", False)
        if not isinstance(emit_generic, bool):
            raise ValueError(
                f"The named argument 'emit_generic' needs to be "
                f"of type bool but got {type(emit_generic)}"
            )

        ctx = ir.Context.current
        op_configs = LinalgOpConfig.from_linalg_op_def(self.op_def, context=ctx)

        if len(op_configs) != 1:
            # TODO: Support composite ops.
            raise NotImplementedError(
                f"Emission of composite linalg ops not supported: {op_configs}"
            )

        # Called only for its side effect; the returned descriptor is unused.
        # NOTE(review): presumably this ensures the linalg dialect is loaded
        # before the is_registered_operation query below — confirm.
        ctx.get_dialect_descriptor("linalg")
        fully_qualified_name = "linalg." + self.op_name
        # Fall back to the generic form when no named op is registered.
        emit_generic = emit_generic or not ctx.is_registered_operation(
            fully_qualified_name
        )

        op_config = op_configs[0]
        out_values = _prepare_structured_op_outs(outs)
        in_values = [_get_op_result_or_value(i) for i in ins]
        if op_config.structured_op:
            if emit_generic:
                return emit_generic_structured_op(
                    op_config.structured_op, *in_values, outs=out_values, **kwargs
                )
            return emit_named_structured_op(
                op_config.structured_op,
                self.op_name,
                self.op_def.metadata.cpp_class_name,
                *in_values,
                outs=out_values,
                **kwargs,
            )

        raise NotImplementedError(
            f"Emission of linalg op type not supported: {op_config}"
        )


def linalg_structured_op(
    dsl_func=None, *, op_name=None, op_class_name=None
) -> DefinedOpCallable:
    """Decorator turning a DSL function into a DefinedOpCallable.

    Usable bare (`@linalg_structured_op`) or with keyword arguments
    (`@linalg_structured_op(op_name=...)`); in the latter form a partial
    awaiting `dsl_func` is returned first.

    Args:
      dsl_func: Function whose parameters are all defaulted to operand
        definitions (TensorDef, ScalarDef, IndexAttrDef, UnaryFnAttrDef,
        BinaryFnAttrDef or TypeFnAttrDef).
      op_name: Op name; defaults to `dsl_func.__name__`.
      op_class_name: C++ class name; defaults to the CamelCased `op_name`
        with an `Op` suffix (e.g. `batch_matmul` -> `BatchMatmulOp`).

    Raises:
      ValueError: If a parameter is not defaulted to an accepted definition.
    """
    if dsl_func is None:
        # Curry the keyword args in for delayed application.
        return functools.partial(
            linalg_structured_op, op_name=op_name, op_class_name=op_class_name
        )
    # Determine default names by introspecting the function.
    if op_name is None:
        op_name = dsl_func.__name__
    if op_class_name is None:
        # Camel case it.
        camel_name = "".join(part.title() for part in op_name.split("_"))
        op_class_name = f"{camel_name}Op"

    op_def = LinalgOpDef(
        name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func)
    )

    # Extract arguments and TensorDefs from the signature.
    dsl_func_args = []
    sig = inspect.signature(dsl_func)
    for param_name, param in sig.parameters.items():
        param_default = param.default
        if isinstance(
            param_default,
            (
                TensorDef,
                ScalarDef,
                IndexAttrDef,
                UnaryFnAttrDef,
                BinaryFnAttrDef,
                TypeFnAttrDef,
            ),
        ):
            op_def.add_operand(param_name, param_default.operand_def)
        else:
            raise ValueError(
                f"@linalg_structured_op function parameters must be defaulted as "
                f"TensorDef(...), ScalarDef(...), IndexAttrDef(...), "
                f"UnaryFnAttrDef(...), BinaryFnAttrDef(...), or TypeFnAttrDef(...): "
                f"Found {param_name}: {param_default}"
            )
        dsl_func_args.append(param_default)

    # Invoke the DSL func to finish populating the op definition.
    with bind_op_def(op_def):
        dsl_func(*dsl_func_args)

    # TODO: The returned callable should be an IR emitter but that is not
    # upstreamed yet.
    return DefinedOpCallable(op_name, op_def)


def domain(*dimensions: DimDef):
    """Appends `dimensions` to the domain of the current op definition.

    Raises:
      ValueError: If any argument is not a DimDef, or no op is being defined.
    """
    if any(not isinstance(d, DimDef) for d in dimensions):
        raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
    current_op_def().domain.extend(dimensions)


def implements(*interfaces: OpInterfaceDef):
    """Records op interfaces implemented by the current op definition.

    Raises:
      ValueError: If any argument is not an OpInterfaceDef, or no op is being
        defined.
    """
    if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces):
        raise ValueError(
            f"Expected interfaces of type OpInterfaceDef but got {interfaces}"
        )
    current_op_def().metadata.implements.extend(interfaces)


def defines(*definitions: OpDefinitionDef):
    """Records definitions attached to the current op definition's metadata.

    Raises:
      ValueError: If any argument is not an OpDefinitionDef, or no op is being
        defined.
    """
    if any(not isinstance(defi, OpDefinitionDef) for defi in definitions):
        raise ValueError(
            f"Expected definitions of type OpDefinitionDef but got {definitions}"
        )
    current_op_def().metadata.defines.extend(definitions)