xref: /llvm-project/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py (revision 0e34dbb4f452013eab89a0a8f04a436ff6c408d4)
1# RUN: env SUPPORT_LIB=%mlir_c_runner_utils \
2# RUN:   %PYTHON %s | FileCheck %s
3
4import ctypes
5import errno
6import itertools
7import os
8import sys
9
10from typing import List, Callable
11
12import numpy as np
13
14from mlir import ir
15from mlir import runtime as rt
16
17from mlir.dialects import bufferization
18from mlir.dialects import builtin
19from mlir.dialects import func
20from mlir.dialects import sparse_tensor as st
21
22_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
23sys.path.append(_SCRIPT_PATH)
24from tools import sparsifier
25
26# ===----------------------------------------------------------------------=== #
27
28
class TypeConverter:
    """Converter between NumPy types and MLIR types.

    Only the element types registered in ``__init__`` are supported; every
    lookup for any other type raises ``KeyError``.
    """

    def __init__(self, context: ir.Context):
        # Note 1: these are numpy "scalar types" (i.e., the values of
        # np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class).
        #
        # Note 2: we must construct the MLIR types in the same context as the
        # types that'll be passed to irtype_to_sctype() or irtype_to_dtype();
        # otherwise, those methods will raise a KeyError.
        types_list = [
            (np.float64, ir.F64Type.get(context=context)),
            (np.float32, ir.F32Type.get(context=context)),
            (np.int64, ir.IntegerType.get_signless(64, context=context)),
            (np.int32, ir.IntegerType.get_signless(32, context=context)),
            (np.int16, ir.IntegerType.get_signless(16, context=context)),
            (np.int8, ir.IntegerType.get_signless(8, context=context)),
        ]
        # NumPy scalar type -> MLIR type.
        self._sc2ir = dict(types_list)
        # MLIR type -> NumPy scalar type.  The loop variable is deliberately
        # not called `ir`, so it cannot be confused with the `ir` module.
        self._ir2sc = {ir_type: sctype for sctype, ir_type in types_list}

    def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type:
        """Returns the MLIR equivalent of a NumPy dtype.

        `dtype` may be an ``np.dtype`` or anything ``np.dtype()`` accepts
        (e.g. a scalar type such as ``np.float64``).

        Raises:
          KeyError: if the dtype's scalar type is not supported.
        """
        # Normalise first (outside the try) so only lookup failures are
        # re-labelled below.
        sctype = np.dtype(dtype).type
        try:
            return self.sctype_to_irtype(sctype)
        except KeyError as e:
            raise KeyError(f"Unknown dtype: {dtype}") from e

    def sctype_to_irtype(self, sctype) -> ir.Type:
        """Returns the MLIR equivalent of a NumPy scalar type.

        Raises:
          KeyError: if `sctype` is not one of the supported scalar types.
        """
        try:
            return self._sc2ir[sctype]
        except KeyError:
            raise KeyError(f"Unknown sctype: {sctype}") from None

    def irtype_to_dtype(self, tp: ir.Type) -> np.dtype:
        """Returns the NumPy dtype equivalent of an MLIR type.

        Raises:
          KeyError: if `tp` is not one of the supported MLIR types.
        """
        return np.dtype(self.irtype_to_sctype(tp))

    def irtype_to_sctype(self, tp: ir.Type):
        """Returns the NumPy scalar-type equivalent of an MLIR type.

        Raises:
          KeyError: if `tp` is not one of the supported MLIR types (including
            an otherwise-matching type built in a different context).
        """
        try:
            return self._ir2sc[tp]
        except KeyError:
            raise KeyError(f"Unknown ir.Type: {tp}") from None

    def get_RankedTensorType_of_nparray(
        self, nparray: np.ndarray
    ) -> ir.RankedTensorType:
        """Returns the ir.RankedTensorType of a NumPy array.  Note that NumPy
        arrays can only be converted to/from dense tensors, not sparse tensors.

        The tensor has the array's shape and the MLIR element type matching
        the array's dtype; raises KeyError for unsupported dtypes.
        """
        return ir.RankedTensorType.get(
            nparray.shape, self.dtype_to_irtype(nparray.dtype)
        )
83
84
85# ===----------------------------------------------------------------------=== #
86
87
class StressTest:
    """Round-trips a dense tensor through a chain of tensor types.

    The methods are meant to be chained in order, e.g.
        StressTest(tyconv).build(types).compile(compiler).run(np_array)
    with optional writeTo() calls in between for debugging.  build(),
    writeTo() and compile() return `self` to allow this chaining.
    """

    def __init__(self, tyconv: TypeConverter):
        # Used by run() to derive tensor types from NumPy arrays.
        self._tyconv = tyconv
        # First (and last) type of the conversion chain; set by build().
        self._roundtripTp = None
        # The generated ir.Module; set by build().
        self._module = None
        # Result of compiler.compile_and_jit(); set by compile().
        self._engine = None

    def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType):
        """Raises AssertionError unless `tp` equals the roundtrip type."""
        assert self._roundtripTp is not None, "StressTest: uninitialized roundtrip type"
        if tp != self._roundtripTp:
            raise AssertionError(
                f"Type is not equal to the roundtrip type.\n"
                f"\tExpected: {self._roundtripTp}\n"
                f"\tFound:    {tp}\n"
            )

    def build(self, types: List[ir.Type]) -> "StressTest":
        """Builds the ir.Module.  The module has only the @main function,
        which will convert the input through the list of types and then back
        to the initial type.  The roundtrip type must be a dense tensor.

        Concretely, @main takes and returns a `types[0]` tensor and applies
        `sparse_tensor.convert` to each of `types[1:]` in order and finally
        back to `types[0]`, releasing every intermediate tensor.  The
        caller's `types` list is not modified.
        """
        assert self._module is None, "StressTest: must not call build() repeatedly"
        assert types, "StressTest: build() requires at least one type"
        # Build the conversion chain locally instead of popping/appending on
        # the caller's list: types[1:] followed by the roundtrip type.
        tp0, *rest = types
        chain = rest + [tp0]
        self._module = ir.Module.create()
        with ir.InsertionPoint(self._module.body):
            self._roundtripTp = tp0
            funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0])
            funcOp = func.FuncOp(name="main", type=funcTp)
            # Request the C-interface wrapper; run() calls @main with
            # ctypes pointers to memref descriptors.
            funcOp.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
            with ir.InsertionPoint(funcOp.add_entry_block()):
                arg0 = funcOp.entry_block.arguments[0]
                self._assertEqualsRoundtripTp(arg0.type)
                v = st.ConvertOp(chain[0], arg0)
                for tp in chain[1:]:
                    w = st.ConvertOp(tp, v)
                    # Release intermediate tensors before they fall out of scope.
                    bufferization.DeallocTensorOp(v.result)
                    v = w
                self._assertEqualsRoundtripTp(v.result.type)
                func.ReturnOp(v)
        return self

    def writeTo(self, filename) -> "StressTest":
        """Write the ir.Module to the given file.  If the file already exists,
        then raises an error.  If the filename is None, then is a no-op.

        Raises:
          FileExistsError: if `filename` already exists.
        """
        assert (
            self._module is not None
        ), "StressTest: must call build() before writeTo()"
        if filename is None:
            # Silent no-op, for convenience.
            return self
        # Mode "x" creates the file exclusively and raises FileExistsError if
        # it already exists, without the race of a separate exists() check.
        with open(filename, "x") as f:
            f.write(str(self._module))
        return self

    def compile(self, compiler) -> "StressTest":
        """Compile the ir.Module via `compiler.compile_and_jit`, keeping the
        resulting execution engine for run()."""
        assert (
            self._module is not None
        ), "StressTest: must call build() before compile()"
        assert self._engine is None, "StressTest: must not call compile() repeatedly"
        self._engine = compiler.compile_and_jit(self._module)
        return self

    def run(self, np_arg0: np.ndarray) -> np.ndarray:
        """Runs the test on the given numpy array, and returns the resulting
        numpy array.

        Raises AssertionError if the array's tensor type differs from the
        roundtrip type given to build().
        """
        assert self._engine is not None, "StressTest: must call compile() before run()"
        self._assertEqualsRoundtripTp(
            self._tyconv.get_RankedTensorType_of_nparray(np_arg0)
        )
        # Only used to obtain a memref descriptor of the right type for the
        # result; the result is read back from mem_out after the call.
        np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype)
        self._assertEqualsRoundtripTp(
            self._tyconv.get_RankedTensorType_of_nparray(np_out)
        )
        mem_arg0 = ctypes.pointer(
            ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0))
        )
        mem_out = ctypes.pointer(
            ctypes.pointer(rt.get_ranked_memref_descriptor(np_out))
        )
        # The result descriptor is passed first and read back afterwards.
        self._engine.invoke("main", mem_out, mem_arg0)
        return rt.ranked_memref_to_numpy(mem_out[0])
173
174
175# ===----------------------------------------------------------------------=== #
176
177
def main():
    """
    USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]

    The environment variable SUPPORT_LIB must be set to point to the
    libmlir_c_runner_utils shared library.  There are two optional
    arguments, for debugging purposes.  The first argument specifies where
    to write out the raw/generated ir.Module.  The second argument specifies
    where to write out the compiled version of that ir.Module.  Neither
    output file may already exist.
    """
    support_lib = os.getenv("SUPPORT_LIB")
    # Explicit check rather than `assert`, so it is not stripped under -O.
    if support_lib is None:
        raise RuntimeError("SUPPORT_LIB is undefined")
    if not os.path.exists(support_lib):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

    # CHECK-LABEL: TEST: test_stress
    print("\nTEST: test_stress")
    with ir.Context() as ctx, ir.Location.unknown():
        sparsification_options = "parallelization-strategy=none "
        compiler = sparsifier.Sparsifier(
            extras="",
            options=sparsification_options,
            opt_level=0,
            shared_libs=[support_lib],
        )
        f64 = ir.F64Type.get()
        # Be careful about increasing this because
        #     len(types) = 1 + len(level_choices)^rank * rank! * len(bitwidths)^2
        shape = range(2, 3)
        rank = len(shape)
        # All combinations of dense/compressed, one level type per dimension.
        dense_lvl = st.EncodingAttr.build_level_type(st.LevelFormat.dense)
        sparse_lvl = st.EncodingAttr.build_level_type(st.LevelFormat.compressed)
        levels = list(itertools.product([dense_lvl, sparse_lvl], repeat=rank))
        # All permutations.
        orderings = [
            ir.AffineMap.get_permutation(perm)
            for perm in itertools.permutations(range(rank))
        ]
        # Values tried for both bitwidth arguments of EncodingAttr.get().
        # NOTE(review): 0 presumably selects the default width — confirm.
        bitwidths = [0]
        # The first type must be a dense tensor for numpy conversion to work.
        types = [ir.RankedTensorType.get(shape, f64)]
        # itertools.product yields tuples in the same order as the equivalent
        # nested for-loops (the rightmost factor varies fastest).
        for level, ordering, pwidth, iwidth in itertools.product(
            levels, orderings, bitwidths, bitwidths
        ):
            attr = st.EncodingAttr.get(level, ordering, None, pwidth, iwidth)
            types.append(ir.RankedTensorType.get(shape, f64, attr))
        #
        # For exhaustiveness we should have one or more StressTest, such
        # that their paths cover all n*(n-1) directed pairwise combinations
        # of the `types` set.  However, since n is already superexponential,
        # such exhaustiveness would be prohibitive for a test that runs on
        # every commit.  So for now we'll just pick one particular path that
        # at least hits all n elements of the `types` set.
        #
        tyconv = TypeConverter(ctx)
        size = int(np.prod(shape))
        np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(
            tuple(shape)
        )
        np_out = (
            StressTest(tyconv)
            .build(types)
            .writeTo(sys.argv[1] if len(sys.argv) > 1 else None)
            .compile(compiler)
            # NOTE(review): this second dump holds the compiled IR only if
            # compile_and_jit() lowers the module in place — confirm.
            .writeTo(sys.argv[2] if len(sys.argv) > 2 else None)
            .run(np_arg0)
        )
        # CHECK: Passed
        if np.allclose(np_out, np_arg0):
            print("Passed")
        else:
            sys.exit("FAILURE")
255
256
if __name__ == "__main__":
    # Run the test only when executed as a script (as the RUN line does via
    # `%PYTHON %s`), not when this file is imported.
    main()
259