xref: /llvm-project/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py (revision f9008e6366c2496b1ca1785b891d5578174ad63e)
#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""YAML serialization is routed through here to centralize common logic."""
5
6import sys
7
8try:
9    import yaml
10except ModuleNotFoundError as e:
11    raise ModuleNotFoundError(
12        f"This tool requires PyYAML but it was not installed. "
13        f"Recommend: {sys.executable} -m pip install PyYAML"
14    ) from e
15
# Names exported by `from <this module> import *`.
__all__ = ["yaml_dump", "yaml_dump_all", "YAMLObject"]
21
22
class YAMLObject(yaml.YAMLObject):
    """Base for objects serialized to YAML as a tagged mapping.

    Subclasses are expected to provide ``yaml_tag`` (as with
    ``yaml.YAMLObject``) and to implement ``to_yaml_custom_dict``, whose
    result is emitted as a mapping under that tag.
    """

    @classmethod
    def to_yaml(cls, dumper, self):
        """Represent an instance as a mapping node tagged with ``cls.yaml_tag``.

        PyYAML calls this classmethod as the representer hook, passing the
        object being dumped as the second positional argument; that is why the
        parameter is named ``self``.

        Args:
          dumper: The PyYAML dumper performing the serialization.
          self: The instance being serialized.

        Returns:
          The node produced by ``dumper.represent_mapping`` for the mapping
          returned by ``self.to_yaml_custom_dict()``.
        """
        return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict())

    def to_yaml_custom_dict(self):
        """Return the mapping to serialize for this object.

        Subclasses must override this.

        Raises:
          NotImplementedError: Always, in this base implementation.
        """
        # Name the concrete class so a missing override is easy to locate.
        raise NotImplementedError(
            f"{type(self).__name__} must implement to_yaml_custom_dict()"
        )

    def as_linalg_yaml(self):
        """Return this object serialized to YAML text via ``yaml_dump``."""
        return yaml_dump(self)
34
35
def multiline_str_representer(dumper, data):
    """Represent a ``str`` scalar, using literal block style for multi-line text.

    A string that ``str.splitlines`` breaks into more than one line is emitted
    with the ``|`` (literal block) style so its line breaks stay readable.
    Every other string is emitted without an explicit style, leaving the choice
    to the dumper.

    Args:
      dumper: The PyYAML dumper performing the serialization.
      data: The string being serialized.

    Returns:
      The scalar node built by ``dumper.represent_scalar``.
    """
    str_tag = "tag:yaml.org,2002:str"
    block_style = "|" if len(data.splitlines()) > 1 else None
    return dumper.represent_scalar(str_tag, data, style=block_style)
41
42
# Route plain `str` values through multiline_str_representer when dumping with
# PyYAML's default Dumper (the Dumper that yaml.dump / yaml.dump_all use when
# none is passed).
# NOTE(review): this registration mutates process-global PyYAML state, so it
# also affects unrelated callers that dump with the default Dumper.
yaml.add_representer(str, multiline_str_representer, Dumper=yaml.Dumper)
44
45
def yaml_dump(data, sort_keys=False, **kwargs):
    """Serialize ``data`` as a single YAML document using ``yaml.dump``.

    Mapping keys keep their insertion order unless ``sort_keys=True`` is
    given. All other keyword arguments are forwarded to ``yaml.dump`` as-is.

    Returns:
      Whatever ``yaml.dump`` returns for these arguments: the YAML text when
      no ``stream`` keyword is supplied.
    """
    dump_options = dict(kwargs)
    dump_options["sort_keys"] = sort_keys
    return yaml.dump(data, **dump_options)
48
49
def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs):
    """Serialize each item of ``data`` as its own YAML document.

    Thin wrapper over ``yaml.dump_all`` with this module's defaults. Keys keep
    insertion order (``sort_keys=False``), and each document starts with an
    explicit document marker (``explicit_start=True``). Any remaining keyword
    arguments pass straight through.

    Returns:
      Whatever ``yaml.dump_all`` returns for these arguments: the YAML text
      when no ``stream`` keyword is supplied.
    """
    dump_options = dict(kwargs)
    dump_options["sort_keys"] = sort_keys
    dump_options["explicit_start"] = explicit_start
    return yaml.dump_all(data, **dump_options)
54