#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#  This file contains functions to convert between Memrefs and NumPy arrays and vice-versa.

import ctypes

import numpy as np

try:
    import ml_dtypes
except ModuleNotFoundError:
    # The third-party ml_dtypes provides some optional low precision data-types for NumPy.
    ml_dtypes = None


class C128(ctypes.Structure):
    """A ctype representation for MLIR's Double Complex."""

    _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]


class C64(ctypes.Structure):
    """A ctype representation for MLIR's Float Complex."""

    _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]


class F16(ctypes.Structure):
    """A ctype representation for MLIR's Float16."""

    _fields_ = [("f16", ctypes.c_int16)]


class BF16(ctypes.Structure):
    """A ctype representation for MLIR's BFloat16."""

    _fields_ = [("bf16", ctypes.c_int16)]


class F8E5M2(ctypes.Structure):
    """A ctype representation for MLIR's Float8E5M2."""

    _fields_ = [("f8E5M2", ctypes.c_int8)]


# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
def as_ctype(dtp):
    """Converts dtype to ctype.

    Complex, float16 and (when ml_dtypes is importable) bfloat16/float8_e5m2
    dtypes map to the structure wrappers defined in this module; every other
    dtype is delegated to `np.ctypeslib.as_ctypes_type`.
    """
    if dtp == np.dtype(np.complex128):
        return C128
    if dtp == np.dtype(np.complex64):
        return C64
    if dtp == np.dtype(np.float16):
        return F16
    if ml_dtypes is not None and dtp == ml_dtypes.bfloat16:
        return BF16
    if ml_dtypes is not None and dtp == ml_dtypes.float8_e5m2:
        return F8E5M2
    return np.ctypeslib.as_ctypes_type(dtp)


def to_numpy(array):
    """Converts ctypes array back to numpy dtype array.

    Arrays whose dtype matches one of the structure wrappers in this module
    are reinterpreted (via `view`) as the corresponding numeric dtype; any
    other array is returned unchanged.

    Raises:
        AssertionError: if the array holds BF16 or F8E5M2 elements but the
            ml_dtypes package could not be imported.
    """
    if array.dtype == C128:
        return array.view("complex128")
    if array.dtype == C64:
        return array.view("complex64")
    if array.dtype == F16:
        return array.view("float16")
    if array.dtype == BF16:
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert (
            ml_dtypes is not None
        ), "bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
        return array.view("bfloat16")
    if array.dtype == F8E5M2:
        assert (
            ml_dtypes is not None
        ), "float8_e5m2 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
        return array.view("float8_e5m2")
    return array


def make_nd_memref_descriptor(rank, dtype):
    """Returns a new ctypes structure class describing a memref of `rank`.

    `dtype` is the ctypes element type the `aligned` pointer refers to.
    """

    class MemRefDescriptor(ctypes.Structure):
        """Builds an empty descriptor for the given rank/dtype, where rank>0."""

        _fields_ = [
            ("allocated", ctypes.c_longlong),
            ("aligned", ctypes.POINTER(dtype)),
            ("offset", ctypes.c_longlong),
            ("shape", ctypes.c_longlong * rank),
            ("strides", ctypes.c_longlong * rank),
        ]

    return MemRefDescriptor


def make_zero_d_memref_descriptor(dtype):
    """Returns a new ctypes structure class describing a rank-0 memref.

    `dtype` is the ctypes element type the `aligned` pointer refers to.
    """

    class MemRefDescriptor(ctypes.Structure):
        """Builds an empty descriptor for the given dtype, where rank=0."""

        _fields_ = [
            ("allocated", ctypes.c_longlong),
            ("aligned", ctypes.POINTER(dtype)),
            ("offset", ctypes.c_longlong),
        ]

    return MemRefDescriptor


class UnrankedMemRefDescriptor(ctypes.Structure):
    """Creates a ctype struct for memref descriptor"""

    _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]


def get_ranked_memref_descriptor(nparray):
    """Returns a ranked memref descriptor for the given numpy array.

    The descriptor points at the array's own buffer (no copy is made); its
    offset is always 0 and its strides are expressed in elements.
    """
    ctp = as_ctype(nparray.dtype)
    rank = nparray.ndim

    if rank == 0:
        descriptor = make_zero_d_memref_descriptor(ctp)()
    else:
        descriptor = make_nd_memref_descriptor(rank, ctp)()

    descriptor.allocated = nparray.ctypes.data
    descriptor.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
    descriptor.offset = 0
    if rank == 0:
        # Rank-0 descriptors carry no shape/strides fields.
        return descriptor

    # Build the index arrays with the exact field type (c_longlong * rank)
    # from Python ints, so the assignment does not depend on numpy's intp
    # ctypes type being the same type as c_longlong.
    index_array_t = ctypes.c_longlong * rank
    descriptor.shape = index_array_t(*nparray.shape)

    # Numpy uses byte quantities to express strides, MLIR OTOH uses the
    # torch abstraction which specifies strides in terms of elements.
    descriptor.strides = index_array_t(
        *(byte_stride // nparray.itemsize for byte_stride in nparray.strides)
    )
    return descriptor


def get_unranked_memref_descriptor(nparray):
    """Returns a generic/unranked memref descriptor for the given numpy array.

    The result stores the array's rank and a type-erased pointer to the
    ranked descriptor built by `get_ranked_memref_descriptor`.
    """
    unranked = UnrankedMemRefDescriptor()
    unranked.rank = nparray.ndim
    ranked = get_ranked_memref_descriptor(nparray)
    unranked.descriptor = ctypes.cast(ctypes.pointer(ranked), ctypes.c_void_p)
    return unranked


def move_aligned_ptr_by_offset(aligned_ptr, offset):
    """Moves the supplied ctypes pointer ahead by `offset` elements.

    Returns a new pointer of the same ctypes pointer type as `aligned_ptr`.
    """
    aligned_addr = ctypes.addressof(aligned_ptr.contents)
    elem_size = ctypes.sizeof(aligned_ptr.contents)
    shift = offset * elem_size
    return ctypes.cast(aligned_addr + shift, type(aligned_ptr))


def _descriptor_to_numpy(descriptor):
    """Builds a numpy array over the buffer described by a ranked descriptor.

    `descriptor` must expose `aligned`, `offset`, `shape` and `strides`.
    """
    content_ptr = move_aligned_ptr_by_offset(descriptor.aligned, descriptor.offset)
    np_arr = np.ctypeslib.as_array(content_ptr, shape=descriptor.shape)
    # Descriptor strides are in elements; numpy expects them in bytes.
    strided_arr = np.lib.stride_tricks.as_strided(
        np_arr,
        np.ctypeslib.as_array(descriptor.shape),
        np.ctypeslib.as_array(descriptor.strides) * np_arr.itemsize,
    )
    return to_numpy(strided_arr)


def unranked_memref_to_numpy(unranked_memref, np_dtype):
    """Converts unranked memrefs to numpy arrays.

    `unranked_memref` is indexed with `[0]` to reach the descriptor (i.e. it
    is a pointer-like object); `np_dtype` selects the element type used to
    reinterpret the type-erased ranked descriptor.
    """
    ctp = as_ctype(np_dtype)
    descriptor_t = make_nd_memref_descriptor(unranked_memref[0].rank, ctp)
    ranked = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor_t))
    return _descriptor_to_numpy(ranked[0])


def ranked_memref_to_numpy(ranked_memref):
    """Converts ranked memrefs to numpy arrays.

    `ranked_memref` is indexed with `[0]` to reach the descriptor (i.e. it is
    a pointer-like object).
    """
    return _descriptor_to_numpy(ranked_memref[0])