#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4"""Model classes representing a tensor comprehension.
5
6These classes model the language more at an AST level as evaluated. Reasoning
7about it typically involves processing this form into config objects that
8represent actual op definitions (i.e. YAML).
9"""

from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from enum import Enum

from ..... import ir as _ir
from .affine import *
from .scalar_expr import *
from .types import *
from .yaml_helper import *

###############################################################################
# Tensor expression nodes.
###############################################################################


class TensorExpression:
    """An expression that can appear on the RHS of a comprehension."""

    def to_scalar_expression(self) -> ScalarExpression:
        raise NotImplementedError()

    def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
        """Visits all tensor expressions reachable from this expression."""
        callback(self)

    def collect_dim_uses(self, uses: Set["DimDef"]):
        """Collects all DimDefs reachable through this expression."""

        def visit_dim_def(dim_def: AffineExprDef):
            if isinstance(dim_def, DimDef):
                uses.add(dim_def)

        def visit_affine_exprs(expr: "TensorExpression"):
            if isinstance(expr, TensorUse):
                for ind in expr.indices:
                    ind.visit_affine_exprs(visit_dim_def)
            if isinstance(expr, TensorReduceFn):
                for ind in expr.reduce_use.reduce_dims:
                    ind.visit_affine_exprs(visit_dim_def)

        self.visit_tensor_exprs(visit_affine_exprs)

    def collect_tensor_uses(self, uses: Set["TensorUse"]):
        """Collects all TensorUses reachable through this expression."""

        def visit_tensor_use(expr: "TensorExpression"):
            if isinstance(expr, TensorUse):
                uses.add(expr)

        self.visit_tensor_exprs(visit_tensor_use)

    def collect_indices(self, indices: Set["index"]):
        """Collects all index accesses reachable through this expression."""

        def visit_index(expr: "TensorExpression"):
            if isinstance(expr, index):
                indices.add(expr)

        self.visit_tensor_exprs(visit_index)

    def collect_scalar_uses(self, uses: Set["ScalarDef"]):
        """Collects all ScalarDefs reachable through this expression."""

        def visit_scalar_def(expr: "TensorExpression"):
            if isinstance(expr, ScalarDef):
                uses.add(expr)

        self.visit_tensor_exprs(visit_scalar_def)

    def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
        return BinaryFn.add(self, rhs)

    def __mul__(self, rhs: "TensorExpression") -> "TensorExpression":
        return BinaryFn.mul(self, rhs)

    def __sub__(self, rhs: "TensorExpression") -> "TensorExpression":
        return BinaryFn.sub(self, rhs)

    def __truediv__(self, rhs: "TensorExpression") -> "TensorExpression":
        return BinaryFn.div(self, rhs)

    def __hash__(self):
        return hash(id(self))
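
# A minimal illustration (assuming TensorDef operands A, B and dims D.m, D.n
# from an op definition): the operator overloads above let expressions be
# written infix, e.g.
#
#   A[D.m, D.n] + B[D.m, D.n]  # equivalent to BinaryFn.add(A[...], B[...])
#   A[D.m, D.n] / B[D.m, D.n]  # equivalent to BinaryFn.div(A[...], B[...])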


class TensorUse(TensorExpression):
    """A used tensor represented by its (tensor_name, indices).

    Note that forming a comprehension via direct assignment is performed through
    __setitem__ on the TensorDef level. However, performing a reduction with
    compound ops (+=, *=, etc.) is done via the sequence:
      TensorDef.__getitem__
      TensorUse.__iadd__
      TensorDef.__setitem__
    """

    def __init__(self, operand_def: "OperandDef", indices: Sequence[AffineExprDef]):
        self.operand_def = operand_def
        self.indices = tuple(indices)

    def to_scalar_expression(self) -> ScalarExpression:
        return ScalarArg(self.tensor_name).expr()

    @property
    def tensor_name(self) -> str:
        name = self.operand_def.name
        assert name is not None, "TensorDef not registered with an op"
        return name

    def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
        # Computes the reduction dims for implicit reductions. Assumes that the
        # rhs is the expression being reduced and self is being reduced into.
        # Any indices referenced on the rhs and not in self are considered
        # reduction dims.
        rhs_dims = set()
        lhs_dims = set()
        rhs.collect_dim_uses(rhs_dims)
        self.collect_dim_uses(lhs_dims)
        return rhs_dims - lhs_dims

    def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
        return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs)

    def __repr__(self):
        return f"{self.operand_def.name}[{', '.join(repr(i) for i in self.indices)}]"
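
# Example (illustrative, assuming an op definition with dims D.m, D.n, D.k and
# tensors A, B, C): an implicit reduction such as
#
#   C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n]
#
# infers D.k as the reduction dim (used on the rhs but not the lhs) and binds
# the resulting TensorReduceFn back to C via TensorDef.__setitem__.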


class TensorFn(TensorExpression):
    """Application of a tensor function."""

    def __init__(
        self,
        kind: "FunctionKind",
        name: Optional[str],
        operand_def: Optional["OperandDef"],
        type_var: Optional[TypeVar],
        args: Sequence[TensorExpression],
    ):
        if bool(name) + bool(operand_def) != 1:
            raise ValueError(
                "Exactly one of 'name' or 'operand_def' must be specified"
            )
        self.name = name
        self.kind = kind
        self.operand_def = operand_def
        self.type_var = type_var
        self.args = args

    def to_scalar_expression(self) -> ScalarExpression:
        if self.operand_def:
            assert self.operand_def.name, "TensorFn not registered with an op"
        attr_name = self.operand_def.name if self.operand_def else None
        args = [arg.to_scalar_expression() for arg in self.args]
        return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr()

    def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
        super().visit_tensor_exprs(callback)
        for arg in self.args:
            arg.visit_tensor_exprs(callback)

    def __repr__(self):
        name = self.operand_def.name if self.operand_def else self.name
        return (
            f"{self.kind.name}.{name}(type_var={self.type_var}, "
            f"args={', '.join(repr(a) for a in self.args)})"
        )



class TensorReduceFn(TensorExpression):
    """Application of a reduction function.

    This captures the lhs (initial value) separately from the rhs.
    """

    def __init__(self, reduce_use: "ReduceFnUse", args: Sequence[TensorExpression]):
        self.reduce_use = reduce_use
        self.lhs = None  # type: Optional[TensorUse]
        self.args = args

    def to_scalar_expression(self) -> ScalarExpression:
        if self.lhs is None:
            raise ValueError(
                f"Cannot scalarize a TensorReduceFn that has not been "
                f"bound to its lhs: {self}"
            )
        full_args = [self.lhs.to_scalar_expression()] + [
            arg.to_scalar_expression() for arg in self.args
        ]
        fn_name = None
        attr_name = None
        if self.reduce_use.binary_fn:
            fn_name = self.reduce_use.binary_fn.fn_name
        if self.reduce_use.binary_attr:
            attr_name = self.reduce_use.binary_attr.operand_def.name
        return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None, full_args).expr()

    def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
        # Visit this node itself (via the base class) so that visitors such as
        # collect_dim_uses can see TensorReduceFn instances, then recurse.
        super().visit_tensor_exprs(callback)
        for arg in self.args:
            arg.visit_tensor_exprs(callback)

    def __repr__(self):
        return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})"


class const(TensorExpression):
    """Returns the given constant floating point or integer value."""

    def __init__(self, value: Any):
        with _ir.Context():
            if isinstance(value, float):
                self.value = str(_ir.FloatAttr.get_f64(float(value)))
            elif isinstance(value, int):
                self.value = str(
                    _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value))
                )
            else:
                raise ValueError(f"const requires int or float but got {type(value)}")

    def to_scalar_expression(self) -> ScalarExpression:
        return ScalarConst(self.value).expr()

    def __repr__(self):
        return f"const({self.value})"
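
# Example (illustrative): constants can appear directly in an expression, e.g.
#
#   C[D.m] = A[D.m] + const(1.0)
#
# where the value is stored as the string form of an MLIR attribute (f64 for
# floats, signless i64 for ints).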


class index(TensorExpression):
    """Returns the iteration index for a given dimension name.

    Resolves the given dimension name to obtain its position in the iteration
    domain of the operation.
    """

    def __init__(self, dim: DimDef):
        self.dim_def = dim
        self.dim = -1

    def resolve_dimension_name(self, affine_state: AffineBuildState):
        self.dim = affine_state.get_dim(self.dim_def.dimname)

    def to_scalar_expression(self) -> ScalarExpression:
        assert self.dim != -1, "Dimension name not resolved"
        return ScalarIndex(self.dim).expr()

    def __repr__(self):
        return f"index({repr(self.dim)})"
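
# Example (illustrative, assuming dim D.m and a type variable T): the current
# iteration index can be materialized into the computation, e.g.
#
#   C[D.m] = TypeFn.cast_signed(T, index(D.m))
#
# The symbolic dim is resolved to its position in the iteration domain when the
# op definition is processed.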


###############################################################################
# Function types and function definitions.
###############################################################################


class FunctionKind(Enum):
    UNARY = 0
    BINARY = 1
    TERNARY = 2
    TYPE = 3

class UnaryFnType:
    """Unary function.

    A unary function takes one tensor expression and returns the
    function evaluation result.
    """

    def __init__(self, fn_name: str):
        self.fn_name = fn_name

    def __call__(self, arg: TensorExpression) -> "TensorFn":
        return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg])

    def __repr__(self):
        return f"{self.fn_name}"


class UnaryFn:
    """Unary function namespace."""

    exp = UnaryFnType("exp")
    log = UnaryFnType("log")
    abs = UnaryFnType("abs")
    ceil = UnaryFnType("ceil")
    floor = UnaryFnType("floor")
    negf = UnaryFnType("negf")
    reciprocal = UnaryFnType("reciprocal")
    round = UnaryFnType("round")
    sqrt = UnaryFnType("sqrt")
    rsqrt = UnaryFnType("rsqrt")
    square = UnaryFnType("square")
    tanh = UnaryFnType("tanh")
    erf = UnaryFnType("erf")


class BinaryFnType:
    """Binary function.

    A binary function takes two tensor expressions and returns the
    function evaluation result.
    """

    def __init__(self, fn_name: str):
        self.fn_name = fn_name

    def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> "TensorFn":
        return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1])

    def __repr__(self):
        return f"{self.fn_name}"


class BinaryFn:
    """Binary function namespace.

    As the integer types are signless, signedness is implemented by different
    functions that treat integers as signed or unsigned values.

    Examples:
    - max_signed -> `arith.MaxSIOp`
    - max_unsigned -> `arith.MaxUIOp`
    """

    add = BinaryFnType("add")
    sub = BinaryFnType("sub")
    mul = BinaryFnType("mul")
    div = BinaryFnType("div")
    div_unsigned = BinaryFnType("div_unsigned")
    max_signed = BinaryFnType("max_signed")
    min_signed = BinaryFnType("min_signed")
    max_unsigned = BinaryFnType("max_unsigned")
    min_unsigned = BinaryFnType("min_unsigned")
    powf = BinaryFnType("powf")
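
# Example (illustrative): signedness is chosen by picking the function, not the
# type, since integer types are signless, e.g.
#
#   C[D.m] = BinaryFn.max_signed(A[D.m], B[D.m])    # arith.MaxSIOp on ints
#   C[D.m] = BinaryFn.max_unsigned(A[D.m], B[D.m])  # arith.MaxUIOp on ints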


class TernaryFnType:
    """Ternary function.

    A ternary function takes three tensor expressions and returns the
    function evaluation result.
    """

    def __init__(self, fn_name: str):
        self.fn_name = fn_name

    def __call__(
        self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression
    ) -> "TensorFn":
        return TensorFn(
            FunctionKind.TERNARY, self.fn_name, None, None, [arg0, arg1, arg2]
        )

    def __repr__(self):
        return f"{self.fn_name}"


class TernaryFn:
    """Ternary function namespace."""

    select = TernaryFnType("select")


class TypeFnType:
    """Type conversion function.

    A type conversion function takes a target type and a tensor expression and
    returns the casted tensor expression.
    """

    def __init__(self, fn_name: str):
        self.fn_name = fn_name

    def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
        return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])

    def __repr__(self):
        return f"{self.fn_name}"


class TypeFn:
    """Type conversion function namespace.

    As the integer types are signless, signedness is implemented by different
    cast functions that treat integers as signed (`cast_signed`) or unsigned
    (`cast_unsigned`) values.

    Examples:
    - cast_signed(I32 -> I64) -> `arith.ExtSIOp`
    - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
    """

    cast_signed = TypeFnType("cast_signed")
    cast_unsigned = TypeFnType("cast_unsigned")
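
# Example (illustrative, assuming a type variable U and an operand A): operands
# are typically cast to the output element type before computing, e.g.
#
#   C[D.m, D.n] = TypeFn.cast_signed(U, A[D.m, D.n])
#
# which sign-extends (or truncates) integer inputs and converts floats.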


class ReduceFnUse:
    """Reduction function use.

    A reduction use specifies the reduction function and dimensions.
    """

    def __init__(
        self,
        binary_fn: Optional[BinaryFnType],
        binary_attr: Optional["BinaryFnAttrDef"],
        *reduce_dims: DimDef,
    ):
        if bool(binary_fn) + bool(binary_attr) != 1:
            raise ValueError(
                "Exactly one of 'binary_fn' or 'binary_attr' must be specified"
            )
        self.binary_fn = binary_fn
        self.binary_attr = binary_attr
        self.reduce_dims = reduce_dims

    def __call__(self, *args: TensorExpression) -> "TensorReduceFn":
        return TensorReduceFn(self, args)

    def __repr__(self):
        fn = self.binary_fn if self.binary_fn else self.binary_attr
        return f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})"


class ReduceFnType:
    """Reduction function.

    A binary function that reduces its RHS into its LHS.
    """

    def __init__(self, binary_fn: BinaryFnType):
        if not isinstance(binary_fn, BinaryFnType):
            raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}")
        self.binary_fn = binary_fn

    def __getitem__(self, reduce_dims: Tuple[DimDef, ...]) -> ReduceFnUse:
        return ReduceFnUse(self.binary_fn, None, *reduce_dims)

    def __repr__(self):
        return f"reduce_{repr(self.binary_fn)}"


class ReduceFn:
    add = ReduceFnType(BinaryFn.add)
    mul = ReduceFnType(BinaryFn.mul)
    max_signed = ReduceFnType(BinaryFn.max_signed)
    min_signed = ReduceFnType(BinaryFn.min_signed)
    max_unsigned = ReduceFnType(BinaryFn.max_unsigned)
    min_unsigned = ReduceFnType(BinaryFn.min_unsigned)
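
# Example (illustrative, assuming dims D.n, D.kh, D.kw and tensors I, O): a
# reduction can also be spelled explicitly with its dims, instead of being
# inferred from `+=`:
#
#   O[D.n] = ReduceFn.max_signed[D.kh, D.kw](I[D.n, D.kh, D.kw])
#
# ReduceFnType.__getitem__ builds the ReduceFnUse, and calling it produces the
# TensorReduceFn that TensorDef.__setitem__ binds to O.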


###############################################################################
# Operand definitions.
###############################################################################


class OperandKind(Enum):
    INPUT_TENSOR = 0
    SCALAR = 1
    OUTPUT_TENSOR = 2
    INDEX_ATTR = 3
    UNARY_FN_ATTR = 4
    BINARY_FN_ATTR = 5
    TERNARY_FN_ATTR = 6
    TYPE_FN_ATTR = 7


class OperandDef:
    """Definition of an operand passed to an operation.

    Keeps the meta information of Tensor, Scalar, and Attribute operands and
    provides the shared registration functionality.
    """

    def __init__(
        self,
        kind: OperandKind,
        type_var: Optional[TypeVar] = None,
        size_exprs: Optional[Sequence[AffineExprDef]] = None,
        index_dims: Optional[Sequence[DimDef]] = None,
        default_indices: Optional[Sequence[int]] = None,
        default_fn: Optional[str] = None,
    ):
        if type_var and not isinstance(type_var, TypeVar):
            raise ValueError(f"OperandDef requires a TypeVar but got {repr(type_var)}")
        self.owner = None  # type: Optional["LinalgOpDef"]
        self.type_var = type_var
        self.size_exprs = size_exprs
        self.index_dims = index_dims
        self.default_indices = default_indices
        self.default_fn = default_fn
        self.kind = kind
        self.name = None  # type: Optional[str]
        self.registered_index = -1  # type: int

    def attach(self, index: int, name: str, owner: "LinalgOpDef"):
        if self.owner:
            raise ValueError(f"OperandDef already registered with an op: {self}")
        self.registered_index = index
        self.name = name
        self.owner = owner

    def is_input(self) -> bool:
        return self.kind == OperandKind.SCALAR or self.kind == OperandKind.INPUT_TENSOR

    def is_tensor(self) -> bool:
        return (
            self.kind == OperandKind.INPUT_TENSOR
            or self.kind == OperandKind.OUTPUT_TENSOR
        )

    def is_attribute(self) -> bool:
        return (
            self.kind == OperandKind.INDEX_ATTR
            or self.kind == OperandKind.UNARY_FN_ATTR
            or self.kind == OperandKind.BINARY_FN_ATTR
            or self.kind == OperandKind.TERNARY_FN_ATTR
            or self.kind == OperandKind.TYPE_FN_ATTR
        )

    def __hash__(self):
        return hash(id(self))

    def __repr__(self):
        return (
            f"{self.name}:OperandDef(kind={self.kind.name}, "
            f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, "
            f"index_dims={self.index_dims}, "
            f"default_indices={self.default_indices}, "
            f"default_fn={self.default_fn})"
        )


class TensorDef:
    """Tensor operand definition.

    Tensor operands are indexed using the associated indexing_map when forwarded
    to the body of the structured op. A unique name identifies the tensor operands
    and an index determines their position in the operation's parameter list. A
    tensor definition takes a type, a shape, and an optional flag to mark output
    tensors. Additionally, a tuple of index dimensions may be used to map the
    tensor to the loop dimensions of the operation. This mapping is needed to
    compute the indexing map of shape-only tensors that have no uses.
    """

    def __init__(
        self,
        type_var: TypeVar,
        *shape: AffineExprDef,
        index_dims: Optional[Sequence[DimDef]] = None,
        output: bool = False,
    ):
        if index_dims and len(shape) != len(index_dims):
            raise ValueError(
                f"Expected the shape rank {len(shape)} to match the "
                f"number of index_dims {len(index_dims)}"
            )
        if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
            raise ValueError(
                f"TensorDef requires index dims of type DimDef but got {index_dims}"
            )
        kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR
        self.operand_def = OperandDef(
            kind, type_var=type_var, size_exprs=shape, index_dims=index_dims
        )

    def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse:
        assert self.operand_def.owner, "TensorDef is not registered with an op"
        state = AffineBuildState(
            global_state=self.operand_def.owner._affine_state, allow_new_symbols=False
        )
        if not isinstance(dims, tuple):
            dims = (dims,)  # Handle single subscript case.
        # Special case: (None) is a 0d-scalar use.
        if dims == (None,):
            dims = ()

        exprs = []
        for expr_def in dims:
            if not isinstance(expr_def, AffineExprDef):
                raise KeyError(
                    "A TensorDef can only be subscripted by a tuple of affine dims"
                )
            exprs.append(expr_def)
        return TensorUse(self.operand_def, exprs)

    def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression):
        """Creates a new 1:1 comprehension by binding this tensor to an expression.

        Note that due to the way assignment works in Python, we have to capture
        direct assignment as a setitem on the TensorDef.
        """
        if not isinstance(value, TensorExpression):
            raise ValueError(
                f"Only TensorExpressions can be assigned to TensorDefs. "
                f"Got: {repr(value)}"
            )
        use = self[dims]
        comp = Comprehension((use, value))
        self.operand_def.owner.comprehensions.append(comp)
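
# Example (illustrative): a matmul-style body in the DSL, assuming an op
# definition that declared
#
#   A = TensorDef(T1, S.M, S.K)
#   B = TensorDef(T2, S.K, S.N)
#   C = TensorDef(U, S.M, S.N, output=True)
#
# would then form its comprehension with
#
#   C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
#       U, B[D.k, D.n]
#   )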


class ScalarDef(TensorExpression):
    """Scalar operand definition.

    Scalar operands are forwarded to the body of the structured op as they are.
    A unique name identifies the scalars and an index determines their position in
    the operation's parameter list.
    """

    def __init__(self, type_var: TypeVar):
        self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var)

    @property
    def scalar_name(self) -> str:
        name = self.operand_def.name
        assert name is not None, "ScalarDef not registered with an op"
        return name

    def to_scalar_expression(self) -> ScalarExpression:
        return ScalarArg(self.scalar_name).expr()


class IndexAttrDef:
    """Index attribute definition.

    Index attributes provide a way to define and set symbols that can be used in
    indexing expressions. Every attribute specifies a tuple of symbols that are
    replaced by integer values at compile time, as well as their default values.
    """

    def __init__(self, *sizes: SymbolDef, default: Sequence[int]):
        if any(not isinstance(size, SymbolDef) for size in sizes):
            raise ValueError(
                f"IndexAttrDef requires sizes of type SymbolDef but got {sizes}"
            )
        if any(not isinstance(default_val, int) for default_val in default):
            raise ValueError(
                f"IndexAttrDef requires default values of type int "
                f"but got {default}"
            )
        if len(sizes) != len(default):
            raise ValueError(
                f"IndexAttrDef expects {len(sizes)} default values "
                f"but got {len(default)}"
            )
        self.operand_def = OperandDef(
            OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default
        )
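
# Example (illustrative): convolution-style ops typically declare their strides
# as an index attribute with defaults, e.g.
#
#   strides = IndexAttrDef(S.SH, S.SW, default=[1, 1])
#
# and then use S.SH / S.SW in the indexing expressions of the inputs.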


class UnaryFnAttrDef:
    """Unary function attribute definition.

    Unary function attributes provide a way to make the arithmetic computation
    parametrizable. Every attribute specifies a default unary function
    that may be overwritten at operation instantiation time.
    """

    def __init__(self, default: "UnaryFnType"):
        if not isinstance(default, UnaryFnType):
            raise ValueError(
                f"UnaryFnAttrDef requires default of type UnaryFnType "
                f"but got {default}"
            )
        self.operand_def = OperandDef(
            OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name
        )

    def __call__(self, arg: TensorExpression) -> TensorFn:
        return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg])


class BinaryFnAttrDef:
    """Binary function attribute definition.

    Binary function attributes provide a way to make the arithmetic computation
    parametrizable. Every attribute specifies a default binary function
    that may be overwritten at operation instantiation time.
    """

    def __init__(self, default: "BinaryFnType"):
        if not isinstance(default, BinaryFnType):
            raise ValueError(
                f"BinaryFnAttrDef requires default of type BinaryFnType "
                f"but got {default}"
            )
        self.operand_def = OperandDef(
            OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name
        )

    def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn:
        return TensorFn(FunctionKind.BINARY, None, self.operand_def, None, [arg0, arg1])

    def __getitem__(self, reduce_dims: Tuple[DimDef, ...]) -> ReduceFnUse:
        return ReduceFnUse(None, self, *reduce_dims)
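
# Example (illustrative): an elementwise op can leave its arithmetic function
# open for the caller by declaring, say,
#
#   fun = BinaryFnAttrDef(default=BinaryFn.add)
#
# and using `fun(A[D.m], B[D.m])` in the comprehension; instantiations may then
# override the default (e.g. with BinaryFn.mul).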


class TernaryFnAttrDef:
    """Ternary function attribute definition.

    Ternary function attributes provide a way to make the arithmetic computation
    parametrizable. Every attribute specifies a default ternary function
    that may be overwritten at operation instantiation time.
    """

    def __init__(self, default: "TernaryFnType"):
        if not isinstance(default, TernaryFnType):
            raise ValueError(
                f"TernaryFnAttrDef requires default of type TernaryFnType "
                f"but got {default}"
            )
        self.operand_def = OperandDef(
            OperandKind.TERNARY_FN_ATTR, default_fn=default.fn_name
        )

    def __call__(
        self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression
    ) -> TensorFn:
        return TensorFn(
            FunctionKind.TERNARY, None, self.operand_def, None, [arg0, arg1, arg2]
        )

    def __getitem__(self, reduce_dims: Tuple[DimDef, ...]) -> ReduceFnUse:
        return ReduceFnUse(None, self, *reduce_dims)


class TypeFnAttrDef:
    """Type conversion function attribute definition.

    Type conversion function attributes provide a way to make type conversions
    parameterizable. Every attribute specifies a default type conversion function
    that may be overwritten at operation instantiation time.
    """

    def __init__(self, default: "TypeFnType"):
        if not isinstance(default, TypeFnType):
            raise ValueError(
                f"TypeFnAttrDef requires default of type TypeFnType "
                f"but got {default}"
            )
        self.operand_def = OperandDef(
            OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name
        )

    def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
        return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg])


###############################################################################
# Operation definition.
###############################################################################


class Comprehension:
    """Represents a single comprehension."""

    def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]):
        self.definitions = list()  # List[TensorUse]
        self.values = list()  # List[TensorExpression]

        # Bind the lhs of each reduction to its rhs.
        for assign, value in bindings:
            if isinstance(value, TensorReduceFn):
                if value.lhs:
                    raise ValueError(f"Reduction expression already assigns: {value}")
                value.lhs = assign
            self.definitions.append(assign)
            self.values.append(value)

    @property
    def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]:
        """Gets the set of reduction dim tuples used by the comprehension."""
        result = set()
        for use in self.values:
            if isinstance(use, TensorReduceFn):
                result.add(use.reduce_use.reduce_dims)
            else:
                result.add(tuple())
        return result

    def __repr__(self):
        if len(self.definitions) > 1:
            defs_repr = f"({', '.join(repr(d) for d in self.definitions)})"
            values_repr = f"({', '.join(repr(v) for v in self.values)})"
        else:
            defs_repr = f"{repr(self.definitions[0])}"
            values_repr = f"{repr(self.values[0])}"

        return f"{defs_repr} = {values_repr}"


class OpInterfaceDef:
    """An interface that an op implements."""

    def __init__(self, cpp_name: str):
        self.cpp_name = cpp_name


ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface")
ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface")
FillOpInterface = OpInterfaceDef("LinalgFillOpInterface")


class OpDefinitionDef:
    """A method that an op implements."""

    def __init__(self, def_name: str):
        self.def_name = def_name


Canonicalizer = OpDefinitionDef("hasCanonicalizer")


class OpMetadataDef(YAMLObject):
    """Metadata about the op (generally not behavior impacting)."""

    yaml_tag = "!LinalgOpMetadata"

    def __init__(self, name: str, cpp_class_name: Optional[str], doc: Optional[str]):
        self.name = name
        self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
        self.doc = doc
        self.implements = []  # type: List[OpInterfaceDef]
        self.defines = []  # type: List[OpDefinitionDef]

    def to_yaml_custom_dict(self):
        d = dict(
            name=self.name,
            cpp_class_name=self.cpp_class_name,
            doc=self.doc,
        )
        if self.implements:
            d["implements"] = [intr.cpp_name for intr in self.implements]
        if self.defines:
            d["defines"] = [defi.def_name for defi in self.defines]
        return d


class LinalgOpDef:
    """Definition of a linalg op."""

    def __init__(
        self, name: str, cpp_class_name: Optional[str] = None, doc: Optional[str] = None
    ):
        self.metadata = OpMetadataDef(name=name, cpp_class_name=cpp_class_name, doc=doc)
        self.registered_operands = dict()  # type: Dict[str, OperandDef]
        self.domain = list()  # type: List[DimDef]
        self.comprehensions = list()  # type: List[Comprehension]
        self._affine_state = AffineBuildState()

    def add_operand(self, name: str, operand: OperandDef):
        """Registers an operand."""
        if name in self.registered_operands:
            raise ValueError(
                f"The operand {name} is already registered "
                f"to {self.registered_operands[name]}"
            )
        structured_op_methods = [
            "inputs",
            "outputs",
            "result_tensors",
            "region",
            "iterator_types",
            "indexing_maps",
            "getRegionBuilder",
            "getLibraryCallName",
        ]
        if operand.is_attribute() and name in structured_op_methods:
            raise ValueError(
                f"The attribute name {name} conflicts with a structured "
                f"op method name"
            )
        # Ensure output tensors are registered after input tensors and scalars,
        # and that attributes are registered after all other operand types.
        if operand.is_input() and any(
            not op_def.is_input() for op_def in self.registered_operands.values()
        ):
            raise ValueError(f"Input {name} registered after an output or attribute")
        if operand.kind == OperandKind.OUTPUT_TENSOR and any(
            op_def.is_attribute() for op_def in self.registered_operands.values()
        ):
            raise ValueError(f"Output {name} registered after an attribute")
        operand.attach(len(self.registered_operands), name, self)
        self.registered_operands[name] = operand


    def __repr__(self):
        lines = [f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"]
        for operand in self.registered_operands.values():
            lines.append(f"  {operand}")
        if self.comprehensions:
            lines[-1] += " {"
            for comprehension in self.comprehensions:
                lines.append(f"    {comprehension}")
            lines.append("}")
        return "\n".join(lines)

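
# A minimal usage sketch (illustrative only; in practice op definitions are
# written with the linalg_structured_op decorator from dsl.py, which builds the
# LinalgOpDef and registers operands from the function signature):
#
#   op = LinalgOpDef("my_op", doc="Example op.")
#   op.add_operand("A", TensorDef(T1, S.M).operand_def)
#   op.add_operand("C", TensorDef(U, S.M, output=True).operand_def)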