# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Model classes representing a tensor comprehension.

These classes model the language more at an AST level as evaluated. Reasoning
about it typically involves processing this form into config objects that
represent actual op definitions (i.e. YAML).
"""

from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from enum import Enum

from ..... import ir as _ir
from .affine import *
from .scalar_expr import *
from .types import *
from .yaml_helper import *

###############################################################################
# Tensor expression nodes.
###############################################################################


class TensorExpression:
    """An expression that can appear on the RHS of a comprehension.

    Subclasses override `to_scalar_expression` and, when they have operands,
    `visit_tensor_exprs`. The arithmetic operators build `TensorFn` nodes
    through the `BinaryFn` namespace.
    """

    def to_scalar_expression(self) -> ScalarExpression:
        """Lowers the expression to a `ScalarExpression`; must be overridden."""
        raise NotImplementedError()

    def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
        """Visits all tensor expressions reachable by the expression.

        The base implementation only invokes `callback` on `self`; subclasses
        with operands extend the traversal to those operands.
        """
        callback(self)

    def collect_dim_uses(self, uses: Set["DimDef"]):
        """Collects all DimDefs reachable through this expression into `uses`."""

        def visit_dim_def(dim_def: AffineExprDef):
            if isinstance(dim_def, DimDef):
                uses.add(dim_def)

        def visit_affine_exprs(expr: "TensorExpression"):
            if isinstance(expr, TensorUse):
                for ind in expr.indices:
                    ind.visit_affine_exprs(visit_dim_def)
            if isinstance(expr, TensorReduceFn):
                # The reduction dims are stored on the ReduceFnUse held in
                # `reduce_use`; TensorReduceFn has no `reduce_fn` attribute.
                for ind in expr.reduce_use.reduce_dims:
                    ind.visit_affine_exprs(visit_dim_def)

        self.visit_tensor_exprs(visit_affine_exprs)

    def collect_tensor_uses(self, uses: Set["TensorUse"]):
        """Collects all TensorUses reachable through this expression into `uses`."""

        def visit_tensor_use(expr: "TensorExpression"):
            if isinstance(expr, TensorUse):
                uses.add(expr)

        self.visit_tensor_exprs(visit_tensor_use)

    def collect_indices(self, indices: Set["index"]):
        """Collects all index accesses reachable through this expression."""

        def visit_index(expr: "TensorExpression"):
            if isinstance(expr, index):
                indices.add(expr)

        self.visit_tensor_exprs(visit_index)

    def collect_scalar_uses(self, uses: Set["ScalarDef"]):
        """Collects all ScalarDefs reachable through this expression into `uses`."""

        def visit_scalar_def(expr: "TensorExpression"):
            if isinstance(expr, ScalarDef):
                uses.add(expr)

        self.visit_tensor_exprs(visit_scalar_def)

    def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
        return BinaryFn.add(self, rhs)

    def __mul__(self, rhs) -> "TensorExpression":
        return BinaryFn.mul(self, rhs)

    def __sub__(self, rhs) -> "TensorExpression":
        return BinaryFn.sub(self, rhs)

    def __truediv__(self, rhs) -> "TensorExpression":
        return BinaryFn.div(self, rhs)

    def __hash__(self):
        # Expressions are hashed by identity so they can be stored in sets.
        return hash(id(self))


class TensorUse(TensorExpression):
    """A used tensor represented by its (tensor_name, indices).

    Note that forming a comprehension via direct assignment is performed through
    __setitem__ on the TensorDef level. However, performing a reduction with
    compound ops (+=, *=, etc) is done by doing a:
      TensorDef.__getitem__
      TensorUse.__iadd__
      TensorDef.__setitem__
    """

    def __init__(self, operand_def: "OperandDef", indices: Sequence[AffineExprDef]):
        self.operand_def = operand_def
        self.indices = tuple(indices)

    def to_scalar_expression(self) -> ScalarExpression:
        """Lowers the use to a scalar argument reference by tensor name."""
        return ScalarArg(self.tensor_name).expr()

    @property
    def tensor_name(self) -> str:
        """The registered operand name; asserts the operand is attached."""
        name = self.operand_def.name
        assert name is not None, "TensorDef not registered with an op"
        return name

    def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
        # Computes the reduction dims for implicit reductions. Assumes that the rhs
        # is the expression being reduced and self is being reduced into. Any
        # indices referenced on the rhs and not in self are considered reduction
        # dims. The result is a set, so the order in which callers unpack the
        # dims is unspecified.
        rhs_dims = set()
        lhs_dims = set()
        rhs.collect_dim_uses(rhs_dims)
        self.collect_dim_uses(lhs_dims)
        return rhs_dims - lhs_dims

    def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
        """Builds an implicit add-reduction of `rhs` into this use."""
        return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs)

    def __repr__(self):
        return (
            f"{self.operand_def.name}" f"[{', '.join([repr(i) for i in self.indices])}]"
        )


class TensorFn(TensorExpression):
    """Application of a tensor function.

    Exactly one of `name` (a fixed function) or `operand_def` (a function
    attribute resolved at instantiation time) must be given.
    """

    def __init__(
        self,
        kind: "FunctionKind",
        name: Optional[str],
        operand_def: Optional["OperandDef"],
        type_var: Optional[TypeVar],
        args: Sequence[TensorExpression],
    ):
        if bool(name) + bool(operand_def) != 1:
            raise ValueError("One of 'name', 'operand_def' must be specified")
        self.name = name
        self.kind = kind
        self.operand_def = operand_def
        self.type_var = type_var
        self.args = args

    def to_scalar_expression(self) -> ScalarExpression:
        """Lowers the function application and its arguments to a ScalarFn."""
        if self.operand_def:
            assert self.operand_def.name, "TensorFn not registered with an op"
        attr_name = self.operand_def.name if self.operand_def else None
        args = [arg.to_scalar_expression() for arg in self.args]
        return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr()

    def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
        """Visits `self`, then recursively every argument."""
        super().visit_tensor_exprs(callback)
        for arg in self.args:
            arg.visit_tensor_exprs(callback)

    def __repr__(self):
        name = self.operand_def.name if self.operand_def else self.name
        return (
            f"{self.kind.name}.{name}(type_var={self.type_var}, "
            f"args={', '.join(repr(a) for a in self.args)})"
        )


class TensorReduceFn(TensorExpression):
    """Application of a reduction function.

    This captures the lhs (initial value) separately from the rhs. The lhs is
    bound later by `Comprehension`.
    """

    def __init__(self, reduce_use: "ReduceFnUse", args: Sequence[TensorExpression]):
        self.reduce_use = reduce_use
        self.lhs = None  # type: Optional[TensorUse]
        self.args = args

    def to_scalar_expression(self) -> ScalarExpression:
        """Lowers the reduction to a binary ScalarFn of (lhs, *args).

        Raises:
          ValueError: if the reduction has not been bound to its lhs yet.
        """
        if self.lhs is None:
            raise ValueError(
                f"Cannot scalarize a TensorReduceFn that has not been "
                f"bound to its lhs: {self}"
            )
        full_args = [self.lhs.to_scalar_expression()] + [
            arg.to_scalar_expression() for arg in self.args
        ]
        fn_name = None
        attr_name = None
        if self.reduce_use.binary_fn:
            fn_name = self.reduce_use.binary_fn.fn_name
        if self.reduce_use.binary_attr:
            attr_name = self.reduce_use.binary_attr.operand_def.name
        return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None, full_args).expr()

    def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
        # Only the arguments are visited; the callback is NOT invoked on the
        # reduction node itself (nor on its not-yet-bound lhs).
        for arg in self.args:
            arg.visit_tensor_exprs(callback)

    def __repr__(self):
        return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})"


class const(TensorExpression):
    """Returns the given constant floating point or integer value.

    The value is stored as the string form of an f64 FloatAttr or a signless
    i64 IntegerAttr.
    """

    def __init__(self, value: Any):
        with _ir.Context():
            if isinstance(value, float):
                self.value = str(_ir.FloatAttr.get_f64(float(value)))
            elif isinstance(value, int):
                self.value = str(
                    _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value))
                )
            else:
                raise ValueError(f"const requires int or float but got {type(value)}")

    def to_scalar_expression(self) -> ScalarExpression:
        return ScalarConst(self.value).expr()

    def __repr__(self):
        return f"const({self.value})"


class index(TensorExpression):
    """Returns the iteration index for a given dimension name.

    Resolves the given dimension name to obtain its position in the iteration
    domain of the operation. Until `resolve_dimension_name` runs, `dim` is -1.
    """

    def __init__(self, dim: DimDef):
        self.dim_def = dim
        self.dim = -1

    def resolve_dimension_name(self, affine_state: AffineBuildState):
        """Looks up the position of `dim_def` in the op's affine state."""
        self.dim = affine_state.get_dim(self.dim_def.dimname)

    def to_scalar_expression(self) -> ScalarExpression:
        assert self.dim != -1, "Dimension name not resolved"
        return ScalarIndex(self.dim).expr()

    def __repr__(self):
        return f"index({repr(self.dim)})"


###############################################################################
# Function types and function definitions.
###############################################################################


class FunctionKind(Enum):
    UNARY = 0
    BINARY = 1
    TERNARY = 2
    TYPE = 3


class UnaryFnType:
    """Unary function.

    A unary function takes one tensor expression and returns the
    function evaluation result.
    """

    def __init__(self, fn_name: str):
        self.fn_name = fn_name

    def __call__(self, arg: TensorExpression) -> "TensorFn":
        return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg])

    def __repr__(self):
        return f"{self.fn_name}"


class UnaryFn:
    """Unary function namespace."""

    exp = UnaryFnType("exp")
    log = UnaryFnType("log")
    abs = UnaryFnType("abs")
    ceil = UnaryFnType("ceil")
    floor = UnaryFnType("floor")
    negf = UnaryFnType("negf")
    reciprocal = UnaryFnType("reciprocal")
    round = UnaryFnType("round")
    sqrt = UnaryFnType("sqrt")
    rsqrt = UnaryFnType("rsqrt")
    square = UnaryFnType("square")
    tanh = UnaryFnType("tanh")
    erf = UnaryFnType("erf")


class BinaryFnType:
    """Binary function.

    A binary function takes two tensor expressions and returns the
    function evaluation result.
    """

    def __init__(self, fn_name: str):
        self.fn_name = fn_name

    def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> "TensorFn":
        return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1])

    def __repr__(self):
        return f"{self.fn_name}"


class BinaryFn:
    """Binary function namespace.

    As the integer types are signless, signedness is implemented by different
    functions that treat integers as signed or unsigned values.

    Examples:
    - max_signed -> `arith.MaxSIOp`
    - max_unsigned -> `arith.MaxUIOp`
    """

    add = BinaryFnType("add")
    sub = BinaryFnType("sub")
    mul = BinaryFnType("mul")
    div = BinaryFnType("div")
    div_unsigned = BinaryFnType("div_unsigned")
    max_signed = BinaryFnType("max_signed")
    min_signed = BinaryFnType("min_signed")
    max_unsigned = BinaryFnType("max_unsigned")
    min_unsigned = BinaryFnType("min_unsigned")
    powf = BinaryFnType("powf")


class TernaryFnType:
    """Ternary function.

    A ternary function takes three tensor expressions and returns the
    function evaluation result.
    """

    def __init__(self, fn_name: str):
        self.fn_name = fn_name

    def __call__(
        self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression
    ) -> "TensorFn":
        return TensorFn(
            FunctionKind.TERNARY, self.fn_name, None, None, [arg0, arg1, arg2]
        )

    def __repr__(self):
        return f"{self.fn_name}"


class TernaryFn:
    """Ternary function namespace."""

    select = TernaryFnType("select")


class TypeFnType:
    """Type conversion function.

    A type conversion function takes a target type and a tensor expression and
    returns the casted tensor expression.
    """

    def __init__(self, fn_name: str):
        self.fn_name = fn_name

    def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
        return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])

    def __repr__(self):
        return f"{self.fn_name}"


class TypeFn:
    """Type conversion function namespace.

    As the integer types are signless, signedness is implemented by different
    cast functions that treat integers as signed (`cast_signed`) or unsigned
    (`cast_unsigned`) values.

    Examples:
    - cast_signed(I32 -> I64) -> `arith.ExtSIOp`
    - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
    """

    cast_signed = TypeFnType("cast_signed")
    cast_unsigned = TypeFnType("cast_unsigned")


class ReduceFnUse:
    """Reduction function use.

    A reduction use specifies the reduction function and dimensions. Exactly
    one of `binary_fn` or `binary_attr` must be given.
    """

    def __init__(
        self,
        binary_fn: Optional[BinaryFnType],
        binary_attr: Optional["BinaryFnAttrDef"],
        *reduce_dims: DimDef,
    ):
        if bool(binary_fn) + bool(binary_attr) != 1:
            raise ValueError("One of 'binary_fn', 'binary_attr' must be specified")
        self.binary_fn = binary_fn
        self.binary_attr = binary_attr
        self.reduce_dims = reduce_dims

    def __call__(self, *args: TensorExpression) -> "TensorReduceFn":
        return TensorReduceFn(self, args)

    def __repr__(self):
        fn = self.binary_fn if self.binary_fn else self.binary_attr
        return f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})"


class ReduceFnType:
    """Reduction function.

    A binary function that reduces its RHS into its LHS.
    """

    def __init__(self, binary_fn: BinaryFnType):
        if not isinstance(binary_fn, BinaryFnType):
            raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}")
        self.binary_fn = binary_fn

    def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
        # A single subscript (e.g. `ReduceFn.add[D.k]`) arrives as a bare DimDef
        # rather than a tuple; wrap it so it can be unpacked below.
        if isinstance(reduce_dims, DimDef):
            reduce_dims = (reduce_dims,)
        return ReduceFnUse(self.binary_fn, None, *reduce_dims)

    def __repr__(self):
        return f"reduce_{repr(self.binary_fn)}"


class ReduceFn:
    """Reduction function namespace."""

    add = ReduceFnType(BinaryFn.add)
    mul = ReduceFnType(BinaryFn.mul)
    max_signed = ReduceFnType(BinaryFn.max_signed)
    min_signed = ReduceFnType(BinaryFn.min_signed)
    max_unsigned = ReduceFnType(BinaryFn.max_unsigned)
    min_unsigned = ReduceFnType(BinaryFn.min_unsigned)


###############################################################################
# Operand definitions.
###############################################################################


class OperandKind(Enum):
    INPUT_TENSOR = 0
    SCALAR = 1
    OUTPUT_TENSOR = 2
    INDEX_ATTR = 3
    UNARY_FN_ATTR = 4
    BINARY_FN_ATTR = 5
    TERNARY_FN_ATTR = 6
    TYPE_FN_ATTR = 7


class OperandDef:
    """Definition of an operand passed to an operation.

    Keep the meta information of Tensor, Scalar, and Attribute operands and
    provide the shared registration functionality.
    """

    def __init__(
        self,
        kind: OperandKind,
        type_var: Optional[TypeVar] = None,
        size_exprs: Optional[Sequence[AffineExprDef]] = None,
        index_dims: Optional[Sequence[DimDef]] = None,
        default_indices: Optional[Sequence[int]] = None,
        default_fn: Optional[str] = None,
    ):
        if type_var and not isinstance(type_var, TypeVar):
            raise ValueError(f"OperandDef requires a TypeVar but got {repr(type_var)}")
        self.owner = None  # type: Optional["LinalgOpDef"]
        self.type_var = type_var
        self.size_exprs = size_exprs
        self.index_dims = index_dims
        self.default_indices = default_indices
        self.default_fn = default_fn
        self.kind = kind
        # Name and position are assigned by `attach` on registration.
        self.name = None  # type: Optional[str]
        self.registered_index = -1  # type: int

    def attach(self, index: int, name: str, owner: "LinalgOpDef"):
        """Binds the operand to `owner` at position `index` under `name`.

        Raises:
          ValueError: if the operand is already attached to an op.
        """
        if self.owner:
            raise ValueError(f"OperandDef already registered with an op: {self}")
        self.registered_index = index
        self.name = name
        self.owner = owner

    def is_input(self) -> bool:
        """True for scalar and input tensor operands."""
        return self.kind == OperandKind.SCALAR or self.kind == OperandKind.INPUT_TENSOR

    def is_tensor(self) -> bool:
        """True for input and output tensor operands."""
        return (
            self.kind == OperandKind.INPUT_TENSOR
            or self.kind == OperandKind.OUTPUT_TENSOR
        )

    def is_attribute(self) -> bool:
        """True for index and function attribute operands."""
        return (
            self.kind == OperandKind.INDEX_ATTR
            or self.kind == OperandKind.UNARY_FN_ATTR
            or self.kind == OperandKind.BINARY_FN_ATTR
            or self.kind == OperandKind.TERNARY_FN_ATTR
            or self.kind == OperandKind.TYPE_FN_ATTR
        )

    def __hash__(self):
        return hash(id(self))

    def __repr__(self):
        return (
            f"{self.name}:OperandDef(kind={self.kind.name}, "
            f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, "
            f"index_dims={self.index_dims}, "
            f"default_indices={self.default_indices}, "
            f"default_fn={self.default_fn})"
        )


class TensorDef:
    """Tensor operand definition.

    Tensor operands are indexed using the associated indexing_map when forwarded
    to the body of the structured op. A unique name identifies the tensor operands
    and an index determines their position in the operation's parameter list. A
    tensor definition takes type, a shape, and an optional flag to mark output
    tensors. Additionally, a tuple of index dimensions may be used to map the
    tensor to the loop dimensions of the operation. This mapping is needed to
    compute the indexing map of shape-only tensors that have no uses.
    """

    def __init__(
        self,
        type_var: TypeVar,
        *shape: AffineExprDef,
        index_dims: Optional[Sequence[DimDef]] = None,
        output: bool = False,
    ):
        if index_dims and len(shape) != len(index_dims):
            raise ValueError(
                f"Expected the shape rank {len(shape)} to match the "
                f"number of index_dims {len(index_dims)}"
            )
        if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
            raise ValueError(
                f"TensorDef requires index dims of type DimDef but " f"got {index_dims}"
            )
        kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR
        self.operand_def = OperandDef(
            kind, type_var=type_var, size_exprs=shape, index_dims=index_dims
        )

    def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse:
        """Creates a TensorUse of this tensor at the given affine indices.

        Raises:
          KeyError: if any subscript is not an AffineExprDef.
        """
        assert self.operand_def.owner, "TensorDef is not registered with an op"
        if not isinstance(dims, tuple):
            dims = (dims,)  # Handle single subscript case.
        # Special case: (None) is a 0d-scalar use.
        if dims == (None,):
            dims = ()

        exprs = []
        for expr_def in dims:
            if not isinstance(expr_def, AffineExprDef):
                raise KeyError(
                    "A TensorDef can only be subscripted by a tuple of affine dims"
                )
            exprs.append(expr_def)
        return TensorUse(self.operand_def, exprs)

    def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression):
        """Creates a new 1:1 comprehension by binding this tensor to an expression.

        Note that due to the way assignment works in Python, we have to capture
        direct assignment as a setitem on the TensorDef.
        """
        if not isinstance(value, TensorExpression):
            raise ValueError(
                f"Only TensorExpressions can be assigned to TensorDefs. "
                f"Got: {repr(value)}"
            )
        use = self[dims]
        comp = Comprehension((use, value))
        self.operand_def.owner.comprehensions.append(comp)


class ScalarDef(TensorExpression):
    """Scalar operand definition.

    Scalar operands are forwarded to the body of the structured op as they are.
    A unique name identifies the scalars and an index determines their position in
    the operation's parameter list.
    """

    def __init__(self, type_var: TypeVar):
        self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var)

    @property
    def scalar_name(self) -> str:
        """The registered operand name; asserts the operand is attached."""
        name = self.operand_def.name
        assert name is not None, "ScalarDef not registered with an op"
        return name

    def to_scalar_expression(self) -> ScalarExpression:
        return ScalarArg(self.scalar_name).expr()


class IndexAttrDef:
    """Index attribute definition.

    Index attributes provide a way to define and set symbols that can be used in
    indexing expressions. Every attribute specifies a tuple of symbols that at
    compile-time are replaced by integer values as well as their default values.
    """

    def __init__(self, *sizes: SymbolDef, default: Sequence[int]):
        if any(not isinstance(size, SymbolDef) for size in sizes):
            raise ValueError(
                f"IndexAttrDef requires sizes of type SymbolDef " f"but got {sizes}"
            )
        if any(not isinstance(default_val, int) for default_val in default):
            raise ValueError(
                f"IndexAttrDef requires default values of type int "
                f"but got {default}"
            )
        if len(sizes) != len(default):
            raise ValueError(
                f"IndexAttrDef expects {len(sizes)} default values "
                f"but got {len(default)}"
            )
        self.operand_def = OperandDef(
            OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default
        )


class UnaryFnAttrDef:
    """Unary function attribute definition.

    Unary function attributes provide a way to make the arithmetic computation
    parametrizable. Every attribute specifies a default unary function
    that may be overwritten at operation instantiation time.
    """

    def __init__(self, default: "UnaryFnType"):
        if not isinstance(default, UnaryFnType):
            raise ValueError(
                f"UnaryFnAttrDef requires default of type UnaryFnType "
                f"but got {default}"
            )
        self.operand_def = OperandDef(
            OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name
        )

    def __call__(self, arg: TensorExpression) -> TensorFn:
        return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg])


class BinaryFnAttrDef:
    """Binary function attribute definition.

    Binary function attributes provide a way to make the arithmetic computation
    parametrizable. Every attribute specifies a default binary function
    that may be overwritten at operation instantiation time.
    """

    def __init__(self, default: "BinaryFnType"):
        if not isinstance(default, BinaryFnType):
            raise ValueError(
                f"BinaryFnAttrDef requires default of type BinaryFnType "
                f"but got {default}"
            )
        self.operand_def = OperandDef(
            OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name
        )

    def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn:
        return TensorFn(FunctionKind.BINARY, None, self.operand_def, None, [arg0, arg1])

    def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
        """Uses this attribute as the combiner of a reduction over `reduce_dims`."""
        # A single subscript arrives as a bare DimDef rather than a tuple.
        if isinstance(reduce_dims, DimDef):
            reduce_dims = (reduce_dims,)
        return ReduceFnUse(None, self, *reduce_dims)


class TernaryFnAttrDef:
    """Ternary function attribute definition.

    Ternary function attributes provide a way to make the arithmetic computation
    parametrizable. Every attribute specifies a default ternary function
    that may be overwritten at operation instantiation time.
    """

    def __init__(self, default: "TernaryFnType"):
        if not isinstance(default, TernaryFnType):
            raise ValueError(
                f"TernaryFnAttrDef requires default of type TernaryFnType "
                f"but got {default}"
            )
        self.operand_def = OperandDef(
            OperandKind.TERNARY_FN_ATTR, default_fn=default.fn_name
        )

    def __call__(
        self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression
    ) -> TensorFn:
        # A ternary function takes three operands, matching TernaryFnType.
        return TensorFn(
            FunctionKind.TERNARY, None, self.operand_def, None, [arg0, arg1, arg2]
        )

    def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
        # NOTE(review): ReduceFnUse treats this as a *binary* reduction
        # attribute; using a ternary attribute as a reduction combiner looks
        # unintended. Kept for interface compatibility.
        return ReduceFnUse(None, self, *reduce_dims)


class TypeFnAttrDef:
    """Type conversion function attribute definition.

    Type conversion function attributes provide a way to make type conversions
    parameterizable. Every attribute specifies a default type conversion function
    that may be overwritten at operation instantiation time.
    """

    def __init__(self, default: "TypeFnType"):
        if not isinstance(default, TypeFnType):
            raise ValueError(
                f"TypeFnAttrDef requires default of type TypeFnType "
                f"but got {default}"
            )
        self.operand_def = OperandDef(
            OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name
        )

    def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
        return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg])


###############################################################################
# Operation definition.
###############################################################################


class Comprehension:
    """Represents a single comprehension.

    Holds parallel lists of lhs `definitions` and rhs `values`. Reduction
    values are bound to their lhs use on construction.
    """

    def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]):
        self.definitions = list()  # List[TensorUse]
        self.values = list()  # List[TensorExpression]

        # Find the lhs to reduction rhs.
        for assign, value in bindings:
            if isinstance(value, TensorReduceFn):
                if value.lhs:
                    raise ValueError(f"Reduction expression already assigns: {value}")
                value.lhs = assign
            self.definitions.append(assign)
            self.values.append(value)

    @property
    def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]:
        """Gets the set of reduction-dim tuples used by the comprehension values.

        Non-reduction values contribute an empty tuple.
        """
        result = set()
        for use in self.values:
            if isinstance(use, TensorReduceFn):
                result.add(use.reduce_use.reduce_dims)
            else:
                result.add(tuple())
        return result

    def __repr__(self):
        if len(self.definitions) > 1:
            defs_repr = f"({', '.join(repr(d) for d in self.definitions)})"
            values_repr = f"({', '.join(repr(v) for v in self.values)})"
        else:
            defs_repr = f"{repr(self.definitions[0])}"
            values_repr = f"{repr(self.values[0])}"

        return f"{defs_repr} = {values_repr}"


class OpInterfaceDef:
    """An interface that an op implements."""

    def __init__(self, cpp_name: str):
        self.cpp_name = cpp_name


ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface")
ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface")
FillOpInterface = OpInterfaceDef("LinalgFillOpInterface")


class OpDefinitionDef:
    """A method that an op implements."""

    def __init__(self, def_name: str):
        self.def_name = def_name


Canonicalizer = OpDefinitionDef("hasCanonicalizer")


class OpMetadataDef(YAMLObject):
    """Metadata about the op (generally not behavior impacting)."""

    yaml_tag = "!LinalgOpMetadata"

    def __init__(self, name: str, cpp_class_name: Optional[str], doc: Optional[str]):
        self.name = name
        self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
        self.doc = doc
        self.implements = []  # type: List[OpInterfaceDef]
        self.defines = []  # type: List[OpDefinitionDef]

    def to_yaml_custom_dict(self):
        """Returns the dict serialized to YAML; optional lists only if non-empty."""
        d = dict(
            name=self.name,
            cpp_class_name=self.cpp_class_name,
            doc=self.doc,
        )
        if self.implements:
            d["implements"] = [intr.cpp_name for intr in self.implements]
        if self.defines:
            d["defines"] = [defi.def_name for defi in self.defines]
        return d


class LinalgOpDef:
    """Definition of a linalg op."""

    def __init__(
        self, name: str, cpp_class_name: Optional[str] = None, doc: Optional[str] = None
    ):
        self.metadata = OpMetadataDef(name=name, cpp_class_name=cpp_class_name, doc=doc)
        self.registered_operands = dict()  # type: Dict[str, OperandDef]
        self.domain = list()  # type: List[DimDef]
        self.comprehensions = list()  # type: List[Comprehension]
        self._affine_state = AffineBuildState()

    def add_operand(self, name: str, operand: OperandDef):
        """Registers an operand.

        Raises:
          ValueError: if `name` is already registered, if an attribute name
            clashes with a structured op method, or if the registration order
            (inputs, then outputs, then attributes) is violated.
        """
        if name in self.registered_operands:
            raise ValueError(
                f"The operand {name} is already registered "
                f"to {self.registered_operands[name]}"
            )
        structured_op_methods = [
            "inputs",
            "outputs",
            "result_tensors",
            "region",
            "iterator_types",
            "indexing_maps",
            "getRegionBuilder",
            "getLibraryCallName",
        ]
        if operand.is_attribute() and name in structured_op_methods:
            raise ValueError(
                f"The attribute name {name} conflicts with a structured "
                f"op method name"
            )
        # Ensure output tensors are registered after input tensors and scalars and
        # attributes are registered after all other operand types.
        if operand.is_input() and any(
            not op_def.is_input() for op_def in self.registered_operands.values()
        ):
            raise ValueError(f"Input {name} registered after an output or attribute")
        if operand.kind == OperandKind.OUTPUT_TENSOR and any(
            op_def.is_attribute() for op_def in self.registered_operands.values()
        ):
            raise ValueError(f"Output {name} registered after an attribute")
        operand.attach(len(self.registered_operands), name, self)
        self.registered_operands[name] = operand

    def __repr__(self):
        lines = [f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"]
        for operand in self.registered_operands.values():
            lines.append(f" {operand}")
        if self.comprehensions:
            lines[-1] += " {"
            for comprehension in self.comprehensions:
                lines.append(f" {comprehension}")
            lines.append("}")
        return "\n".join(lines)