#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Represents configured ops as emitted for code generation.

Classes in this module generally are directly serializable to YAML for use
by the code generator.

TODO: These should just be dumb containers or serialization code but they
currently encode too many details of how the language is interpreted. Move this
to helpers on the comprehension objects themselves.
"""

from typing import Dict, List, Optional, Sequence, Set, Tuple

from ..... import ir as _ir
from .comprehension import *
from .yaml_helper import *

__all__ = ["LinalgStructuredOpConfig", "LinalgOpConfig", "OperandDefConfig"]


def _serialize_affine_map(affine_map: _ir.AffineMap) -> str:
    with affine_map.context:
        # Affine map printing/parsing is via an AffineMap attr.
        attr = _ir.AffineMapAttr.get(affine_map)
        return str(attr)
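
# Illustrative usage (a sketch, not exercised by this module): in a fresh
# context, a permutation map round-trips through the attribute printer:
#
#   with _ir.Context():
#       m = _ir.AffineMap.get_permutation([1, 0])
#       _serialize_affine_map(m)  # -> "affine_map<(d0, d1) -> (d1, d0)>"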


class TensorUseConfig:
    """Wrapper around a TensorUse with additional context-bound state."""

    def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap):
        self.tensor_use = tensor_use
        self.indexing_map = indexing_map

    def __repr__(self):
        return f"Use({self.tensor_use}, indexing_map={self.indexing_map})"


class OperandDefConfig(YAMLObject):
    """Wrapper containing an operand definition with additional state."""

    yaml_tag = "!LinalgOperandDefConfig"

    def __init__(
        self,
        operand_def: OperandDef,
        shape_map: Optional[_ir.AffineMap] = None,
        index_attr_map: Optional[_ir.AffineMap] = None,
    ):
        self.operand_def = operand_def
        self.shape_map = shape_map  # type: Optional[_ir.AffineMap]
        self.index_attr_map = index_attr_map  # type: Optional[_ir.AffineMap]
        self.indexing_map = None  # type: Optional[_ir.AffineMap]

    @property
    def name(self) -> str:
        return self.operand_def.name

    @property
    def kind(self) -> OperandKind:
        return self.operand_def.kind

    @property
    def type_var(self) -> TypeVar:
        return self.operand_def.type_var

    def to_yaml_custom_dict(self):
        self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower())
        if self.type_var:
            self_dict["type_var"] = self.type_var.name
        if self.shape_map:
            self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
        if self.index_attr_map:
            self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map)
        if self.operand_def.default_indices:
            self_dict["default_indices"] = self.operand_def.default_indices
        if self.operand_def.default_fn:
            self_dict["default_fn"] = self.operand_def.default_fn
        return self_dict

    def __repr__(self):
        return (
            f"OperandDefConfig({self.operand_def}, "
            f"shape_map={self.shape_map}, "
            f"index_attr_map={self.index_attr_map}, "
            f"indexing_map={self.indexing_map})"
        )
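
# For reference, a hand-written sketch of the YAML this produces for an input
# tensor with a symbolic shape map (names such as A and T1 are illustrative):
#
#   !LinalgOperandDefConfig
#   name: A
#   kind: input_tensor
#   type_var: T1
#   shape_map: affine_map<()[s0, s1] -> (s0, s1)>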


class LinalgIndexingMapsConfig(YAMLObject):
    """Abstracts the style of indexing maps that the op exports.

    Presently only static (tied to the op name) indexing maps are supported. In
    the future, it is expected that we will have additional variants:
      - Dynamic based on attributes
      - Dynamic based on operands
    Each is expected to require a different variant of specification.
    """

    yaml_tag = "!LinalgIndexingMapsConfig"

    def __init__(self, static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None):
        self.static_indexing_maps = static_indexing_maps

    def to_yaml_custom_dict(self):
        if self.static_indexing_maps is not None:
            return dict(
                static_indexing_maps=[
                    _serialize_affine_map(m) for m in self.static_indexing_maps
                ]
            )
        raise ValueError(
            "LinalgIndexingMapsConfig must have one type of indexing map "
            "(got none)"
        )
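
    # A hand-written sketch of the emitted YAML for the static case (shown for
    # a matmul-like op; the exact maps depend on the op definition):
    #
    #   indexing_maps: !LinalgIndexingMapsConfig
    #     static_indexing_maps:
    #     - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
    #     - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
    #     - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>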


class LinalgStructuredOpConfig(YAMLObject):
    """Configuration for metadata sufficient to construct a linalg named op."""

    yaml_tag = "!LinalgStructuredOpConfig"

    def __init__(
        self,
        comprehension: Comprehension,
        domain: Sequence[DimDef],
        registered_operands: Sequence[OperandDef],
        context: Optional[_ir.Context] = None,
    ):
        self.context = context if context is not None else _ir.Context()
        self.affine_state = AffineBuildState()
        self.writes = list()  # type: List[Tuple[TensorUse, TensorExpression]]
        self.operands = dict()  # type: Dict[OperandDef, OperandDefConfig]
        self.uses = dict()  # type: Dict[TensorUse, TensorUseConfig]

        # Compute the ordered set of writes and collect the tensor, scalar,
        # dim, and index uses.
        collected_tensor_uses = set()
        collected_scalar_uses = set()
        collected_dim_uses = set()
        collected_indices = set()
        for write_use, read_use in zip(comprehension.definitions, comprehension.values):
            self.writes.append((write_use, read_use))

        for write_use, read_use in self.writes:
            collected_tensor_uses.add(write_use)
            read_use.collect_tensor_uses(collected_tensor_uses)
            read_use.collect_scalar_uses(collected_scalar_uses)
            read_use.collect_dim_uses(collected_dim_uses)
            write_use.collect_dim_uses(collected_dim_uses)
            read_use.collect_indices(collected_indices)

        # Set domain to the sorted list of uses if no domain annotation is given.
        if not domain:
            domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname)

        # Verify the domain dimensions match the used dimensions.
        if len(domain) != len(collected_dim_uses) or any(
            dim not in collected_dim_uses for dim in domain
        ):
            raise ValueError(
                f"Expected the annotated domain dimensions {domain} to "
                f"match the set of dimensions used by the tensor "
                f"comprehension {collected_dim_uses}"
            )

        # Instantiate the dimensions in the given order.
        with self.context:
            local_state = AffineBuildState(
                global_state=self.affine_state, allow_new_symbols=False
            )
            for dim in domain:
                dim.build(state=local_state)

        # Collect all attribute definitions.
        collected_attr_defs = list()
        for operand in registered_operands:
            if operand.is_attribute():
                collected_attr_defs.append(operand)

        # Collect all tensors with manual indexing annotation.
        collected_index_defs = list()
        for operand in registered_operands:
            if operand.index_dims:
                if any(dim not in collected_dim_uses for dim in operand.index_dims):
                    raise ValueError(
                        f"Expected all index dims {operand.index_dims} of "
                        f"operand {operand.name} to have uses."
                    )
                collected_index_defs.append(operand)

        # Collect the operand definitions of all tensor/scalar uses, attributes, and
        # shape-only tensors.
        all_operand_defs = list()
        for use in collected_tensor_uses:
            all_operand_defs.append(use.operand_def)
        for use in collected_scalar_uses:
            all_operand_defs.append(use.operand_def)
        for definition in collected_attr_defs:
            all_operand_defs.append(definition)
        for definition in collected_index_defs:
            all_operand_defs.append(definition)

        # Add all operands in registration order to ensure the symbols are
        # registered in the order they appear.
        all_operand_defs = sorted(
            all_operand_defs, key=lambda operand_def: operand_def.registered_index
        )
        for operand_def in all_operand_defs:
            self.add_operand(operand_def)

        # Add all shape-only tensor index_dim annotations and all tensor uses.
        for definition in collected_index_defs:
            self.add_indexed_operand(definition)
        for use in collected_tensor_uses:
            self.add_tensor_use(use)

        # Normalize all shape and indexing maps now that the full counts of
        # dims and symbols are known.
        for cuse in self.uses.values():
            cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
        for definition in collected_index_defs:
            self.operands[definition].indexing_map = self._normalize_affine_map(
                self.operands[definition].indexing_map
            )
        for operand_config in self.operands.values():
            if operand_config.shape_map:
                operand_config.shape_map = self._normalize_affine_map(
                    operand_config.shape_map, with_dims=False
                )
            if operand_config.index_attr_map:
                operand_config.index_attr_map = self._normalize_affine_map(
                    operand_config.index_attr_map, with_dims=False
                )

        # Now for each write use, propagate the indexing maps from the use to
        # the tensor, ensuring that there are no conflicts.
        for write_use, _ in self.writes:
            write_tensor_config = self.operands[write_use.operand_def]
            if write_tensor_config.indexing_map:
                raise ValueError(
                    f"Unexpected multi-write to a single tensor: {write_tensor_config}"
                )
            write_tensor_config.indexing_map = self.uses[write_use].indexing_map

        # For each read use, propagate the indexing maps from the use to the
        # tensor, ensuring that there are no conflicts.
        for _, read_expr in self.writes:
            read_uses = set()  # type: Set[TensorUse]
            read_expr.collect_tensor_uses(read_uses)
            for read_use in read_uses:
                read_operand_config = self.operands[read_use.operand_def]
                if (
                    read_operand_config.indexing_map
                    and read_operand_config.indexing_map
                    != self.uses[read_use].indexing_map
                ):
                    raise ValueError(
                        f"Unexpected multi-read of a tensor with different accesses: "
                        f"{read_operand_config} vs {read_use}"
                    )
                read_operand_config.indexing_map = self.uses[read_use].indexing_map

        # Set the indexing map of all scalar uses to the empty map.
        for operand_config in self.operands.values():
            if operand_config.operand_def.kind == OperandKind.SCALAR:
                operand_config.indexing_map = self._get_scalar_map()

        # Check all registered tensor and scalar operands have an indexing map.
        for operand in registered_operands:
            if operand.is_attribute():
                continue
            if not (operand in self.operands and self.operands[operand].indexing_map):
                raise ValueError(
                    f"Failed to compute an indexing map for operand {operand.name}"
                )

        # Collect the reduction dims and ensure they are all the same.
        all_reduction_dims = set(comprehension.all_reduction_dims)
        if len(all_reduction_dims) != 1:
            raise ValueError(
                f"All writes within a generic must have the same reduction "
                f"dims. Got: {all_reduction_dims}"
            )
        self.reduction_dims = next(iter(all_reduction_dims))

        # Check that each index dimension exists and resolve it.
        for index in collected_indices:
            if index.dim_def.dimname not in self.affine_state.all_dims:
                raise ValueError(
                    f"The dimension {index.dim_def.dimname} is not part of the "
                    f"iteration domain {self.affine_state.all_dims}"
                )
            index.resolve_dimension_name(self.affine_state)

        # Generate the scalar assignments (used to build a body).
        self.assignments = [
            ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression())
            for write_use, read_expr in self.writes
        ]

    @property
    def ordered_operands(self) -> Sequence[OperandDefConfig]:
        return sorted(
            self.operands.values(),
            key=lambda operand: operand.operand_def.registered_index,
        )

    @property
    def ordered_dims(self) -> Sequence[Tuple[str, int]]:
        """Gets the ordered list of dim bindings (symbolic name, position).

        TODO: The original parser relies on parse ordering to arrive at the
        iterator types, but that ordering is not defined on the Python side, so
        this may be ambiguous.
        """
        return list(self.affine_state.all_dims.items())
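
    # Illustrative: for a matmul-like comprehension whose dims were registered
    # in the order m, n, k, this returns [("m", 0), ("n", 1), ("k", 2)].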

    @property
    def indexing_maps(self) -> Sequence[_ir.AffineMap]:
        return [o.indexing_map for o in self.ordered_operands if o.indexing_map]

    @property
    def iterator_types(self) -> Sequence[str]:
        def get_type(symbolic_name, position):
            for reduction_dim_expr in self.reduction_dims:
                if reduction_dim_expr.dimname == symbolic_name:
                    return "reduction"
            return "parallel"

        return [get_type(*dim) for dim in self.ordered_dims]
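
    # Illustrative: with ordered dims [("m", 0), ("n", 1), ("k", 2)] and
    # reduction dims (k,), this yields ["parallel", "parallel", "reduction"].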

    def add_operand(self, operand_def: OperandDef):
        if operand_def in self.operands:
            return
        if not (operand_def.is_tensor() or operand_def.kind == OperandKind.INDEX_ATTR):
            self.operands[operand_def] = OperandDefConfig(operand_def)
            return
        with self.context:
            local_state = AffineBuildState(
                global_state=self.affine_state, allow_new_dims=False
            )
            exprs = []
            for expr in operand_def.size_exprs:
                exprs.append(expr.build(state=local_state))
            assert local_state.local_dim_count == 0
            affine_map = _ir.AffineMap.get(
                dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs
            )
            if operand_def.kind == OperandKind.INDEX_ATTR:
                self.operands[operand_def] = OperandDefConfig(
                    operand_def, index_attr_map=affine_map
                )
            else:
                self.operands[operand_def] = OperandDefConfig(
                    operand_def, shape_map=affine_map
                )
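
    # Illustrative: a tensor operand declared with size expressions (S.M, S.K)
    # receives a symbol-only shape map such as "()[s0, s1] -> (s0, s1)"; the
    # symbol count is widened later by _normalize_affine_map.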

    def add_indexed_operand(self, operand_def: OperandDef):
        with self.context:
            local_state = AffineBuildState(
                global_state=self.affine_state, allow_new_symbols=False
            )
            exprs = []
            for expr in operand_def.index_dims:
                exprs.append(expr.build(state=local_state))
            self.operands[operand_def].indexing_map = _ir.AffineMap.get(
                dim_count=local_state.dim_count,
                symbol_count=local_state.symbol_count,
                exprs=exprs,
            )

    def add_tensor_use(self, tensor_use: TensorUse):
        if tensor_use in self.uses:
            return
        with self.context:
            local_state = AffineBuildState(
                global_state=self.affine_state, allow_new_symbols=False
            )
            exprs = []
            for expr in tensor_use.indices:
                exprs.append(expr.build(state=local_state))
            indexing_map = _ir.AffineMap.get(
                dim_count=local_state.dim_count,
                symbol_count=local_state.symbol_count,
                exprs=exprs,
            )

            use_config = TensorUseConfig(tensor_use, indexing_map)
            self.uses[tensor_use] = use_config
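
    # Illustrative: a use such as A[D.m, D.k] in a three-dim domain builds an
    # indexing map of roughly "(d0, d1, d2) -> (d0, d2)" (dim/symbol counts
    # are made uniform later by _normalize_affine_map).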

    def _get_scalar_map(self) -> _ir.AffineMap:
        """Create an empty affine map used to index a scalar."""
        with self.context:
            return _ir.AffineMap.get(
                dim_count=self.affine_state.dim_count,
                symbol_count=self.affine_state.symbol_count,
                exprs=list(),
            )
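
    # Illustrative: assuming a three-dim domain and no symbols, the scalar map
    # is "(d0, d1, d2) -> ()", i.e. a rank-0 access valid at every iteration.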

    def _normalize_affine_map(
        self, affine_map: _ir.AffineMap, with_dims: bool = True
    ) -> _ir.AffineMap:
        """Normalizes an indexing map to have the max known symbols and dims."""
        with self.context:
            return _ir.AffineMap.get(
                dim_count=self.affine_state.dim_count if with_dims else 0,
                symbol_count=self.affine_state.symbol_count,
                exprs=list(affine_map.results),
            )
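
    # Illustrative: if the op has two symbols overall, a local shape map
    # "()[s0] -> (s0)" normalizes to "()[s0, s1] -> (s0)", giving every map in
    # the op consistent dim and symbol counts.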

    def to_yaml_custom_dict(self):
        self_dict = dict(args=self.ordered_operands)
        # TODO: Refactor the hierarchy internally when supporting more than
        # static indexing maps (preserving this serialized form).
        self_dict["indexing_maps"] = LinalgIndexingMapsConfig(
            static_indexing_maps=self.indexing_maps
        )
        self_dict["iterator_types"] = self.iterator_types
        self_dict["assignments"] = self.assignments
        return self_dict

    def __repr__(self):
        lines = [f"LinalgStructuredOpConfig(reduction_dims={self.reduction_dims},"]
        lines.append("operands=[")
        for def_config in self.ordered_operands:
            lines.append(f"  {repr(def_config)}")
        lines.append("], indexing_maps=[")
        for m in self.indexing_maps:
            lines.append(f"  {repr(m)}")
        lines.append("], iterator_types=[")
        for t in self.iterator_types:
            lines.append(f"  {t}")
        lines.append("])")
        return "\n".join(lines)


class LinalgOpConfig(YAMLObject):
    """Container for any supported linalg op type.

    This includes the concrete type by name for ease of parsing by systems
    that ignore tags.
    """

    yaml_tag = "!LinalgOpConfig"

    def __init__(
        self,
        metadata: OpMetadataDef,
        *,
        structured_op: Optional[LinalgStructuredOpConfig] = None,
    ):
        self.metadata = metadata
        self.structured_op = structured_op

    def to_yaml_custom_dict(self):
        self_dict = dict(
            metadata=self.metadata,
        )
        if self.structured_op:
            self_dict["structured_op"] = self.structured_op
        return self_dict

    @staticmethod
    def from_linalg_op_def(
        op_def: LinalgOpDef, context: Optional[_ir.Context] = None
    ) -> Sequence["LinalgOpConfig"]:
        """Expands a LinalgOpDef into corresponding Linalg configured ops."""
        # TODO: Many LinalgOpDef patterns need to expand to multiple generics.
        assert len(op_def.comprehensions) == 1, "Only one comprehension supported"
        return [
            LinalgOpConfig(
                op_def.metadata,
                structured_op=LinalgStructuredOpConfig(
                    op_def.comprehensions[0],
                    op_def.domain,
                    op_def.registered_operands.values(),
                    context,
                ),
            ),
        ]
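
    # Example use (a sketch; yaml_dump_all from yaml_helper and op_def, an
    # already-built LinalgOpDef, are assumed):
    #
    #   configs = LinalgOpConfig.from_linalg_op_def(op_def)
    #   print(yaml_dump_all(configs))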

    def __repr__(self):
        return (
            f"LinalgOpConfig(metadata={self.metadata},\n"
            f"structured_op={self.structured_op})"
        )
