xref: /llvm-project/mlir/python/mlir/runtime/np_to_memref.py (revision c8cac33ad23acc671a0a7390a5254b9f6e848138)
1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5#  This file contains functions to convert between Memrefs and NumPy arrays and vice-versa.
6
7import numpy as np
8import ctypes
9
10try:
11    import ml_dtypes
12except ModuleNotFoundError:
13    # The third-party ml_dtypes provides some optional low precision data-types for NumPy.
14    ml_dtypes = None
15
16
class C128(ctypes.Structure):
    """ctypes stand-in for MLIR's double complex: two C doubles, real then imag."""

    _fields_ = [
        ("real", ctypes.c_double),
        ("imag", ctypes.c_double),
    ]
21
22
class C64(ctypes.Structure):
    """ctypes stand-in for MLIR's float complex: two C floats, real then imag."""

    _fields_ = [
        ("real", ctypes.c_float),
        ("imag", ctypes.c_float),
    ]
27
28
class F16(ctypes.Structure):
    """ctypes stand-in for MLIR's Float16: a struct with one 16-bit integer field."""

    _fields_ = [
        ("f16", ctypes.c_int16),
    ]
33
34
class BF16(ctypes.Structure):
    """ctypes stand-in for MLIR's BFloat16: a struct with one 16-bit integer field."""

    _fields_ = [
        ("bf16", ctypes.c_int16),
    ]
39
class F8E5M2(ctypes.Structure):
    """ctypes stand-in for MLIR's Float8E5M2: a struct with one 8-bit integer field."""

    _fields_ = [
        ("f8E5M2", ctypes.c_int8),
    ]
44
45
46# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
def as_ctype(dtp):
    """Maps a numpy dtype to the ctypes type used to describe its elements.

    Complex and half-precision dtypes are mapped to the placeholder structs
    declared in this module. The bfloat16 and float8_e5m2 dtypes are only
    recognised when the optional ml_dtypes package was imported. Every other
    dtype is delegated to `np.ctypeslib.as_ctypes_type`.
    """
    # Checked in this order; the first dtype that compares equal wins.
    numpy_placeholders = (
        (np.complex128, C128),
        (np.complex64, C64),
        (np.float16, F16),
    )
    for scalar_type, placeholder in numpy_placeholders:
        if dtp == np.dtype(scalar_type):
            return placeholder

    if ml_dtypes is not None:
        if dtp == ml_dtypes.bfloat16:
            return BF16
        if dtp == ml_dtypes.float8_e5m2:
            return F8E5M2

    return np.ctypeslib.as_ctypes_type(dtp)
60
61
def to_numpy(array):
    """Converts an array of ctypes placeholder structs back to its numpy dtype.

    Arrays whose dtype matches one of the placeholder structs (C128, C64, F16,
    BF16, F8E5M2) are reinterpreted with `ndarray.view` as the corresponding
    numpy dtype. Any other array is returned unchanged.

    Raises:
      AssertionError: if the array holds BF16 or F8E5M2 elements but the
        optional ml_dtypes package could not be imported.
    """
    source_dtype = array.dtype
    if source_dtype == C128:
        return array.view("complex128")
    if source_dtype == C64:
        return array.view("complex64")
    if source_dtype == F16:
        return array.view("float16")

    # These dtype names are only known to numpy once ml_dtypes is imported.
    for placeholder, dtype_name in ((BF16, "bfloat16"), (F8E5M2, "float8_e5m2")):
        if source_dtype != placeholder:
            continue
        if ml_dtypes is None:
            # Raised explicitly rather than via `assert` so the check still
            # runs under `python -O`; the exception type is kept as
            # AssertionError for backward compatibility.
            raise AssertionError(
                f"{dtype_name} requires the ml_dtypes package, please run:"
                "\n\npip install ml_dtypes\n"
            )
        return array.view(dtype_name)
    return array
81
82
def make_nd_memref_descriptor(rank, dtype):
    """Returns a new ctypes struct class describing a ranked memref.

    Args:
      rank: number of entries in the `shape` and `strides` arrays.
      dtype: ctypes type of one element; `aligned` is a pointer to it.
    """
    element_ptr_t = ctypes.POINTER(dtype)
    index_array_t = ctypes.c_longlong * rank

    class MemRefDescriptor(ctypes.Structure):
        """Empty memref descriptor for the given rank/dtype, where rank>0."""

        _fields_ = [
            ("allocated", ctypes.c_longlong),
            ("aligned", element_ptr_t),
            ("offset", ctypes.c_longlong),
            ("shape", index_array_t),
            ("strides", index_array_t),
        ]

    return MemRefDescriptor
96
97
def make_zero_d_memref_descriptor(dtype):
    """Returns a new ctypes struct class describing a rank-0 memref.

    Args:
      dtype: ctypes type of the single element; `aligned` is a pointer to it.
    """
    element_ptr_t = ctypes.POINTER(dtype)

    class MemRefDescriptor(ctypes.Structure):
        """Empty memref descriptor for the given dtype, where rank=0."""

        # No shape/strides members: only the two addresses and the offset.
        _fields_ = [
            ("allocated", ctypes.c_longlong),
            ("aligned", element_ptr_t),
            ("offset", ctypes.c_longlong),
        ]

    return MemRefDescriptor
109
110
class UnrankedMemRefDescriptor(ctypes.Structure):
    """ctypes struct for an unranked memref: a rank plus an untyped pointer."""

    _fields_ = [
        ("rank", ctypes.c_longlong),
        # Type-erased address of the ranked descriptor.
        ("descriptor", ctypes.c_void_p),
    ]
115
116
def get_ranked_memref_descriptor(nparray):
    """Builds a ranked memref descriptor that points at nparray's buffer.

    A 0-d array gets the shape-less descriptor; any other array additionally
    gets its `shape` and element-based `strides` filled in. `offset` is always
    zero and both address fields refer to the start of the array's data.

    NOTE(review): the descriptor stores addresses only, so the caller
    presumably has to keep `nparray` alive while it is in use - confirm.
    """
    elem_ctype = as_ctype(nparray.dtype)
    rank = nparray.ndim
    if rank == 0:
        descriptor = make_zero_d_memref_descriptor(elem_ctype)()
    else:
        descriptor = make_nd_memref_descriptor(rank, elem_ctype)()

    descriptor.allocated = nparray.ctypes.data
    descriptor.aligned = nparray.ctypes.data_as(ctypes.POINTER(elem_ctype))
    descriptor.offset = ctypes.c_longlong(0)
    if rank == 0:
        return descriptor

    descriptor.shape = nparray.ctypes.shape

    # Numpy uses byte quantities to express strides, MLIR OTOH uses the
    # torch abstraction which specifies strides in terms of elements.
    element_strides = [
        byte_stride // nparray.itemsize for byte_stride in nparray.strides
    ]
    descriptor.strides = (ctypes.c_longlong * rank)(*element_strides)
    return descriptor
138
139
def get_unranked_memref_descriptor(nparray):
    """Wraps nparray's ranked descriptor in a generic/unranked descriptor.

    The result records the array's rank and the type-erased address of a
    freshly built ranked descriptor for the same array.
    """
    unranked = UnrankedMemRefDescriptor()
    unranked.rank = nparray.ndim
    ranked = get_ranked_memref_descriptor(nparray)
    erased_ptr = ctypes.cast(ctypes.pointer(ranked), ctypes.c_void_p)
    unranked.descriptor = erased_ptr
    return unranked
147
148
def move_aligned_ptr_by_offset(aligned_ptr, offset):
    """Returns a pointer of the same type advanced by `offset` elements.

    The supplied pointer is dereferenced to obtain both its address and the
    element size, so it must not be NULL.
    """
    pointee = aligned_ptr.contents
    target_addr = ctypes.addressof(pointee) + offset * ctypes.sizeof(pointee)
    return ctypes.cast(target_addr, type(aligned_ptr))
156
157
def unranked_memref_to_numpy(unranked_memref, np_dtype):
    """Converts an unranked memref into a numpy array of dtype `np_dtype`.

    The type-erased descriptor is reinterpreted as a ranked descriptor whose
    rank comes from the unranked memref and whose element type is derived from
    `np_dtype`. The data is then exposed as a strided view starting `offset`
    elements past the aligned pointer.
    """
    elem_ctype = as_ctype(np_dtype)
    unranked = unranked_memref[0]
    ranked_t = make_nd_memref_descriptor(unranked.rank, elem_ctype)
    ranked = ctypes.cast(unranked.descriptor, ctypes.POINTER(ranked_t))[0]

    data_ptr = move_aligned_ptr_by_offset(ranked.aligned, ranked.offset)
    base_arr = np.ctypeslib.as_array(data_ptr, shape=ranked.shape)

    dims = np.ctypeslib.as_array(ranked.shape)
    # The descriptor counts strides in elements; numpy wants bytes.
    byte_strides = np.ctypeslib.as_array(ranked.strides) * base_arr.itemsize
    strided_view = np.lib.stride_tricks.as_strided(base_arr, dims, byte_strides)
    return to_numpy(strided_view)
171
172
def ranked_memref_to_numpy(ranked_memref):
    """Converts a ranked memref into a numpy array.

    The data is exposed as a strided view that starts `offset` elements past
    the descriptor's aligned pointer and uses the descriptor's shape and
    strides.
    """
    descriptor = ranked_memref[0]
    data_ptr = move_aligned_ptr_by_offset(descriptor.aligned, descriptor.offset)
    base_arr = np.ctypeslib.as_array(data_ptr, shape=descriptor.shape)

    dims = np.ctypeslib.as_array(descriptor.shape)
    # The descriptor counts strides in elements; numpy wants bytes.
    byte_strides = np.ctypeslib.as_array(descriptor.strides) * base_arr.itemsize
    strided_view = np.lib.stride_tricks.as_strided(base_arr, dims, byte_strides)
    return to_numpy(strided_view)
185