# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._func_ops_gen import *
from ._func_ops_gen import _Dialect

try:
    from ..ir import *
    from ._ods_common import (
        get_default_loc_context as _get_default_loc_context,
        _cext as _ods_cext,
    )

    import inspect

    from typing import Any, List, Optional, Sequence, Union
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
RESULT_ATTRIBUTE_NAME = "res_attrs"


@_ods_cext.register_operation(_Dialect, replace=True)
class ConstantOp(ConstantOp):
    """Specialization for the constant op class."""

    @property
    def type(self):
        """The type of the op's first (and only) result."""
        return self.results[0].type


@_ods_cext.register_operation(_Dialect, replace=True)
class FuncOp(FuncOp):
    """Specialization for the func op class."""

    def __init__(
        self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
    ):
        """
        Create a FuncOp with the provided `name`, `type`, and `visibility`.
        - `name` is converted with `str()` and stored as the `sym_name` attribute.
        - `type` is either a FunctionType or an `(inputs, results)` tuple of type
          lists, from which a FunctionType is built.
        - `visibility` is a string matching `public`, `private`, or `nested`. None
          leaves the `sym_visibility` attribute unset, which MLIR treats as
          public visibility.
        - `body_builder` is an optional callback, when provided a new entry block
          is created and the callback is invoked with the new op as argument within
          an InsertionPoint context already set for the block. The callback is
          expected to insert a terminator in the block.
        """
        sym_name = StringAttr.get(str(name))

        # If the type is passed as a tuple, build a FunctionType on the fly.
        if isinstance(type, tuple):
            type = FunctionType.get(inputs=type[0], results=type[1])

        type = TypeAttr.get(type)
        sym_visibility = (
            StringAttr.get(str(visibility)) if visibility is not None else None
        )
        super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
        if body_builder:
            entry_block = self.add_entry_block()
            with InsertionPoint(entry_block):
                body_builder(self)

    @property
    def is_external(self):
        """True when the function body region has no blocks (a declaration)."""
        return len(self.regions[0].blocks) == 0

    @property
    def body(self):
        """The function body region."""
        return self.regions[0]

    @property
    def type(self):
        """The FunctionType stored in the `function_type` attribute."""
        return FunctionType(TypeAttr(self.attributes["function_type"]).value)

    @property
    def visibility(self):
        """The raw `sym_visibility` attribute (raises KeyError when unset)."""
        return self.attributes["sym_visibility"]

    @property
    def name(self) -> StringAttr:
        """The `sym_name` attribute as a StringAttr."""
        return StringAttr(self.attributes["sym_name"])

    @property
    def entry_block(self):
        """The first block of the body.

        Raises:
          IndexError: if the function is external (has no body).
        """
        if self.is_external:
            raise IndexError("External function does not have a body")
        return self.regions[0].blocks[0]

    def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
        """
        Add an entry block to the function body using the function signature to
        infer block arguments.
        Returns the newly created block.

        Raises:
          IndexError: if the function already has an entry block.
        """
        if not self.is_external:
            raise IndexError("The function already has an entry block!")
        self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
        return self.body.blocks[0]

    @property
    def arg_attrs(self):
        """Per-argument attribute dictionaries as an ArrayAttr.

        When the `arg_attrs` attribute is absent, an ArrayAttr holding one empty
        DictAttr per function input is returned.
        """
        if ARGUMENT_ATTRIBUTE_NAME not in self.attributes:
            # Build the default in the op's own context so no ambient context
            # is required.
            return ArrayAttr.get(
                [DictAttr.get({}, context=self.context) for _ in self.type.inputs],
                context=self.context,
            )
        return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])

    @arg_attrs.setter
    def arg_attrs(self, attribute: Union[ArrayAttr, list]):
        if isinstance(attribute, ArrayAttr):
            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
        else:
            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
                attribute, context=self.context
            )

    @property
    def arguments(self):
        """The block arguments of the entry block."""
        return self.entry_block.arguments

    @property
    def result_attrs(self):
        """Per-result attribute dictionaries.

        When the `res_attrs` attribute is absent, an ArrayAttr holding one empty
        DictAttr per function result is returned (mirroring `arg_attrs`).
        """
        if RESULT_ATTRIBUTE_NAME not in self.attributes:
            return ArrayAttr.get(
                [DictAttr.get({}, context=self.context) for _ in self.type.results],
                context=self.context,
            )
        return self.attributes[RESULT_ATTRIBUTE_NAME]

    @result_attrs.setter
    def result_attrs(self, attribute: Union[ArrayAttr, list]):
        # Plain Python sequences are wrapped into an ArrayAttr; attributes are
        # stored as given.
        if isinstance(attribute, (list, tuple)):
            attribute = ArrayAttr.get(list(attribute), context=self.context)
        self.attributes[RESULT_ATTRIBUTE_NAME] = attribute

    @classmethod
    def from_py_func(
        cls,
        *inputs: Type,
        results: Optional[Sequence[Type]] = None,
        name: Optional[str] = None,
    ):
        """Decorator to define an MLIR FuncOp specified as a python function.

        Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
        active for the current thread (i.e. established in a `with` block).

        When applied as a decorator to a Python function, an entry block will
        be constructed for the FuncOp with types as specified in `*inputs`. The
        block arguments will be passed positionally to the Python function. In
        addition, if the Python function accepts keyword arguments generally or
        has a corresponding keyword argument, the following will be passed:
          * `func_op`: The `func` op being defined.

        By default, the function name will be the Python function `__name__`. This
        can be overridden by passing the `name` argument to the decorator.

        If `results` is not specified, then the decorator will implicitly
        insert a `ReturnOp` with the `Value`'s returned from the decorated
        function. It will also set the `FuncOp` type with the actual return
        value types. If `results` is specified, then the decorated function
        must return `None` and no implicit `ReturnOp` is added (nor are the result
        types updated). The implicit behavior is intended for simple, single-block
        cases, and users should specify result types explicitly for any complicated
        cases.

        The decorated function can further be called from Python and will insert
        a `CallOp` at the then-current insertion point, returning either None (if
        no return values), a single Value (for one result), or a list of Values.
        This mechanism cannot be used to emit recursive calls (by construction).
        """

        def decorator(f):
            from . import func

            # Introspect the callable for optional features: pass `func_op` if
            # the callable takes **kwargs or a keyword-capable `func_op` param.
            sig = inspect.signature(f)
            has_arg_func_op = False
            for param in sig.parameters.values():
                if param.kind == param.VAR_KEYWORD:
                    has_arg_func_op = True
                if param.name == "func_op" and (
                    param.kind == param.POSITIONAL_OR_KEYWORD
                    or param.kind == param.KEYWORD_ONLY
                ):
                    has_arg_func_op = True

            # Emit the FuncOp.
            implicit_return = results is None
            symbol_name = name or f.__name__
            function_type = FunctionType.get(
                inputs=inputs, results=[] if implicit_return else results
            )
            func_op = cls(name=symbol_name, type=function_type)
            with InsertionPoint(func_op.add_entry_block()):
                func_args = func_op.entry_block.arguments
                func_kwargs = {}
                if has_arg_func_op:
                    func_kwargs["func_op"] = func_op
                return_values = f(*func_args, **func_kwargs)
                if not implicit_return:
                    return_types = list(results)
                    assert return_values is None, (
                        "Capturing a python function with explicit `results=` "
                        "requires that the wrapped function returns None."
                    )
                else:
                    # Coerce return values, add ReturnOp and rewrite func type.
                    if return_values is None:
                        return_values = []
                    elif isinstance(return_values, tuple):
                        return_values = list(return_values)
                    elif isinstance(return_values, Value):
                        # Returning a single value is fine, coerce it into a list.
                        return_values = [return_values]
                    elif isinstance(return_values, OpView):
                        # Returning a single operation is fine, coerce its results
                        # into a list.
                        return_values = return_values.operation.results
                    elif isinstance(return_values, Operation):
                        # Returning a single operation is fine, coerce its results
                        # into a list.
                        return_values = return_values.results
                    else:
                        return_values = list(return_values)
                    func.ReturnOp(return_values)
                    # Recompute the function type.
                    return_types = [v.type for v in return_values]
                    function_type = FunctionType.get(
                        inputs=inputs, results=return_types
                    )
                    func_op.attributes["function_type"] = TypeAttr.get(function_type)

            def emit_call_op(*call_args):
                call_op = func.CallOp(
                    return_types, FlatSymbolRefAttr.get(symbol_name), call_args
                )
                # `return_types` is always a list here; an empty list means the
                # callee has no results, so return None as documented.
                if not return_types:
                    return None
                if len(return_types) == 1:
                    return call_op.result
                return call_op.results

            wrapped = emit_call_op
            wrapped.__name__ = f.__name__
            wrapped.func_op = func_op
            return wrapped

        return decorator


func = FuncOp.from_py_func


@_ods_cext.register_operation(_Dialect, replace=True)
class CallOp(CallOp):
    """Specialization for the call op class."""

    def __init__(
        self,
        calleeOrResults: Union[FuncOp, List[Type]],
        argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
        arguments: Optional[List] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Creates a call operation.

        The constructor accepts three different forms:

          1. A function op to be called followed by a list of arguments.
          2. A list of result types, followed by the name of the function to be
             called as string, followed by a list of arguments.
          3. A list of result types, followed by the name of the function to be
             called as symbol reference attribute, followed by a list of arguments.

        For example

            f = func.FuncOp("foo", ...)
            func.CallOp(f, [args])
            func.CallOp([result_types], "foo", [args])

        In all cases, the location and insertion point may be specified as keyword
        arguments if not provided by the surrounding context managers.

        Raises:
          ValueError: if the arguments do not match any of the forms above.
        """

        # TODO: consider supporting constructor "overloads", e.g., through a custom
        # or pybind-provided metaclass.
        if isinstance(calleeOrResults, FuncOp):
            if not isinstance(argumentsOrCallee, list):
                raise ValueError(
                    "when constructing a call to a function, expected "
                    + "the second argument to be a list of call arguments, "
                    + f"got {type(argumentsOrCallee)}"
                )
            if arguments is not None:
                raise ValueError(
                    "unexpected third argument when constructing a call "
                    + "to a function"
                )

            super().__init__(
                calleeOrResults.type.results,
                FlatSymbolRefAttr.get(
                    calleeOrResults.name.value, context=_get_default_loc_context(loc)
                ),
                argumentsOrCallee,
                loc=loc,
                ip=ip,
            )
            return

        if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
            callee = argumentsOrCallee
        elif isinstance(argumentsOrCallee, str):
            callee = FlatSymbolRefAttr.get(
                argumentsOrCallee, context=_get_default_loc_context(loc)
            )
        else:
            # Reject lists and any other unsupported type so the op is never
            # left without its underlying operation being created.
            raise ValueError(
                "when constructing a call to a function by name, "
                + "expected the second argument to be a string or a "
                + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}"
            )

        super().__init__(calleeOrResults, callee, arguments, loc=loc, ip=ip)