from enum import Enum
import functools, sys, ctypes, os, errno
import numpy as np
from functools import partialmethod
from mlir import ir
from mlir.dialects import arith, func, gpu, memref, nvgpu, scf, nvvm
from mlir.extras import types as T
from mlir import runtime as rt
from tools import nvgpucompiler

MLIR_DYNAMIC = -9223372036854775808


def const(value: int, ty=None):
    ty = T.index() if ty is None else ty
    if isinstance(value, ir.Value) and (
        value.type.isinstance(value.type) or T.bool().isinstance(value.type)
    ):
        return value
    return arith.constant(ty, value)


def get_type_size(ty):
    if ir.MemRefType.isinstance(ty):
        size = get_type_size(ty.element_type)
        for sz in ty.shape:
            size *= sz
        return size
    if ir.FloatType.isinstance(ty):
        return ir.FloatType(ty).width // 8
    if ir.IntegerType.isinstance(ty):
        return ir.IntegerType(ty).width // 8
    raise NotImplementedError(ty)


def get_mlir_func_obj_ty(inputArgs):
    args = []
    c_int_p = ctypes.c_int * 1
    c_float_p = ctypes.c_float * 1
    c_bool_p = ctypes.c_bool * 1
    for arg in inputArgs:
        if isinstance(arg, bool):
            args.append(c_bool_p(arg))
        elif isinstance(arg, int):
            args.append(c_int_p(arg))
        elif isinstance(arg, float):
            args.append(c_float_p(arg))
        elif isinstance(arg, np.ndarray):
            args.append(
                ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(arg)))
            )
        else:
            raise NotImplementedError(arg)
    return args


class Mbarriers:
    """A class that builds and drives a group of mbarrier objects in shared memory."""

    def __init__(self, number_of_barriers=1):
        self.mbar_ty = ir.Type.parse(
            "!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>, num_barriers = "
            + str(number_of_barriers)
            + ">"
        )
        self.mbar_group_op = nvgpu.mbarrier_create(self.mbar_ty)
        self.number_of_barriers = number_of_barriers

    def __getitem__(self, key):
        self.id_op = const(key)
        return self

    def init(self, count: int, predicate=None):
        count_op = const(count)
        if predicate is None:
            nvgpu.mbarrier_init(self.mbar_group_op, count_op, self.id_op)
        else:
            nvgpu.mbarrier_init(
                self.mbar_group_op, count_op, self.id_op, predicate=predicate
            )

    def arrive(self, txcount: int = 0, predicate=None):
        if txcount != 0:
            txcount_op = const(txcount)
            nvgpu.mbarrier_arrive_expect_tx(
                self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
            )
        else:
            nvgpu.mbarrier_arrive(
                ir.Type.parse("!nvgpu.mbarrier.token"), self.mbar_group_op, self.id_op
            )

    def try_wait(self, phase: bool = False, ticks: int = 10000000):
        ticks_op = const(ticks)
        phase_op = const(phase, T.bool())
        nvgpu.MBarrierTryWaitParityOp(
            self.mbar_group_op,
            phase_op,
            ticks_op,
            mbarId=self.id_op,
        )
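
# Illustrative usage sketch (not executed; `is_leader`, `phase`, and the byte
# count are assumptions, not part of this file): create a barrier group, let
# one thread initialize it, signal the expected TMA transfer size on the
# producer side, and spin-wait on the parity phase on the consumer side.
#
#   mbar = Mbarriers(number_of_barriers=2)
#   mbar[0].init(1, predicate=is_leader)
#   mbar[0].arrive(txcount=16384, predicate=is_leader)
#   mbar[0].try_wait(phase=phase)
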
class TMA:
    """A class that builds a TMA descriptor."""

    def __init__(
        self,
        tma_box_shape,
        memref_ty,
        swizzle=nvgpu.TensorMapSwizzleKind.SWIZZLE_NONE,
        l2promo=nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
        oob=nvgpu.TensorMapOOBKind.OOB_ZERO,
        interleave=nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
    ):
        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
        self.tma_box_shape = tma_box_shape
        self.memref_ty = memref_ty  # MemRefType
        self.tma_memref = ir.MemRefType.get(tma_box_shape, memref_ty.element_type)

    @property
    def tensormap_descriptor_ty(self):
        """Returns a tensormap descriptor type."""
        tensorMemrefType = ir.MemRefType.get(
            self.tma_box_shape,
            self.memref_ty.element_type,
            memory_space=ir.Attribute.parse("3"),
        )
        return nvgpu.TensorMapDescriptorType.get(
            tensorMemrefType,
            self.swizzle,
            self.l2promo,
            self.oob,
            self.interleave,
        )

    def create_descriptor(self, device_ptr):
        tma_descriptor_ty = self.tensormap_descriptor_ty
        device_unranked_memref = memref.CastOp(
            ir.UnrankedMemRefType.get(
                self.memref_ty.element_type, self.memref_ty.memory_space
            ),
            device_ptr,
        )
        self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
            tma_descriptor_ty, device_unranked_memref, map(const, self.tma_box_shape)
        )
        return self.tma_descriptor.result

    def prefetch(self, predicate=None):
        nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)

    def load(self, dest, mbarrier: Mbarriers, coords=[0], predicate=None):
        nvgpu.TmaAsyncLoadOp(
            dest,
            mbarrier.mbar_group_op,
            self.tma_descriptor,
            coordinates=map(const, coords),
            mbarId=mbarrier.id_op,
            predicate=predicate,
        )


WARP_GROUP_SIZE = 128  # Number of threads in a warpgroup


class Warpgroup:
    """Context manager that emits an scf.if region executed only by the selected
    warpgroup and adjusts its register budget with nvvm.setmaxregister."""

    def __init__(self, primary_thread, register_size):
        assert (primary_thread % WARP_GROUP_SIZE) == 0
        tidx = gpu.thread_id(gpu.Dimension.x)
        self.primary_thread = primary_thread
        self.register_size = register_size
        self.is_wg_primary = (tidx % WARP_GROUP_SIZE) == 0
        self.wg_id = tidx / WARP_GROUP_SIZE
        self.is_me = self.wg_id == (primary_thread // WARP_GROUP_SIZE)

    def __enter__(self):
        if_op = scf.IfOp(self.is_me)
        self.ipoint_op = ir.InsertionPoint(if_op.then_block)
        self.ipoint_op.__enter__()
        if self.register_size < 64:
            nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.decrease)
        else:
            nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.increase)

    def __exit__(self, exc_type, exc_value, traceback):
        scf.yield_([])
        self.ipoint_op.__exit__(exc_type, exc_value, traceback)
        return True


class WGMMAType(Enum):
    Accumulator = 1
    Descriptor = 2


class WGMMAMatrix:
    """Wraps a WGMMA shared-memory operand descriptor or register accumulator;
    the `@` and `+=` operators build nvgpu warpgroup MMA operations."""

    def __init__(
        self,
        matrix_type: WGMMAType,
        shape: list = None,
        desc: TMA = None,
        smem=None,
        ty=None,
        acc_op=None,
    ):
        if acc_op is None:
            self.M = shape[0]
            self.N = shape[1]
            self.ty = ty
            self.matrix_type = matrix_type
            self.desc = desc
            self.smem = smem
            if matrix_type is WGMMAType.Accumulator:
                self.acc_op = nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
        elif acc_op:
            self.acc_op = acc_op
            self.matrix_type = WGMMAType.Accumulator

    @property
    def acc_ty(self):
        parse_str = f"!nvgpu.warpgroup.accumulator<fragmented=vector<{self.M}x{self.N}x{self.ty}>>"
        return ir.Type.parse(parse_str)

    @property
    def wgmma_ty(self):
        parse_str = f"!nvgpu.warpgroup.descriptor<tensor=memref<{self.M}x{self.N}x{self.desc.memref_ty.element_type}, #gpu.address_space<workgroup>>>"
        return ir.Type.parse(parse_str)

    def store_accumulator(self, dest):
        assert self.matrix_type == WGMMAType.Accumulator
        nvgpu.warpgroup_mma_store(self.acc_op, dest)

    def update_smem(self, smem):
        self.smem = smem

    def update_accumulator(self, acc_op):
        self.acc_op = acc_op

    def __matmul__(self, rhs):
        lhs = nvgpu.warpgroup_generate_descriptor(
            self.wgmma_ty, self.smem, self.desc.tma_descriptor
        )
        rhs = nvgpu.warpgroup_generate_descriptor(
            rhs.wgmma_ty, rhs.smem, rhs.desc.tma_descriptor
        )
        return [lhs, rhs]

    def __iadd__(self, matmulResult):
        lhs = matmulResult[0]
        rhs = matmulResult[1]
        acc_op = nvgpu.WarpgroupMmaOp(
            self.acc_op.type, lhs, rhs, self.acc_op, transposeB=True
        )
        return WGMMAMatrix(WGMMAType.Accumulator, acc_op=acc_op)
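
# Illustrative sketch (not executed; `a_tma`, `b_tma`, `a_smem`, `b_smem`,
# `d_smem`, and the M/N/K sizes are assumptions): operand matrices are wrapped
# as descriptors over shared memory, the accumulator is initialized in
# registers, and `acc += A @ B` lowers to an nvgpu warpgroup MMA operation.
#
#   A = WGMMAMatrix(WGMMAType.Descriptor, [M, K], desc=a_tma, smem=a_smem)
#   B = WGMMAMatrix(WGMMAType.Descriptor, [K, N], desc=b_tma, smem=b_smem)
#   acc = WGMMAMatrix(WGMMAType.Accumulator, shape=[M, N], ty=T.f32())
#   acc += A @ B
#   acc.store_accumulator(d_smem)
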
def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
    smem_space_str = "#gpu.address_space<workgroup>"
    smem_space = ir.Attribute.parse(smem_space_str)
    dynamic_smem = gpu.dynamic_shared_memory(
        ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
    )
    if shape is None:
        return dynamic_smem
    memref_ty = ir.MemRefType.get(shape, ty, memory_space=smem_space)
    return memref.view(
        ir.MemRefType.get(
            memref_ty.shape, memref_ty.element_type, memory_space=smem_space
        ),
        dynamic_smem,
        const(offset),
        [],
    )


def get_mlir_ty(arg):
    def get_mlir_ty_from_np(dtype):
        if dtype == np.float16:
            return T.f16()
        if dtype == np.float32:
            return T.f32()
        if dtype == np.float64:
            return T.f64()
        if dtype == np.int32:
            return T.i32()
        if dtype == np.int64:
            return T.i64()
        raise NotImplementedError(dtype)

    if isinstance(arg, bool):
        return T.bool()
    elif isinstance(arg, int):
        return T.index()
    elif isinstance(arg, float):
        return T.f32()
    elif isinstance(arg, np.ndarray):
        descriptor = rt.get_ranked_memref_descriptor(arg)
        dtype = get_mlir_ty_from_np(arg.dtype)
        shape = descriptor.shape
        return memref.MemRefType.get(shape, dtype)
    raise NotImplementedError(arg)
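
# Illustrative sketch (not executed; the shapes and the byte offset are
# assumptions): `get_dynamic_shared_memory` either returns the raw dynamic
# shared-memory block or a typed memref.view carved out at a byte offset, so
# two tiles can be laid out back to back.
#
#   a_smem = get_dynamic_shared_memory((128, 64), T.f16(), offset=0)
#   b_smem = get_dynamic_shared_memory((64, 128), T.f16(), offset=128 * 64 * 2)
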
class NVDSL:
    """Decorators that wrap Python functions into GPU kernel launches and into
    host functions whose MLIR modules are built, verified, JIT-compiled, and run."""

    @staticmethod
    def mlir_gpu_launch(grid=(1, 1, 1), block=(1, 1, 1), smem=0):
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                launch_op = gpu.LaunchOp(
                    None,
                    [],
                    *map(const, grid),
                    *map(const, block),
                    dynamicSharedMemorySize=arith.constant(T.i32(), smem),
                )
                launch_op.body.blocks.append(*([T.index()] * 12))
                with ir.InsertionPoint(launch_op.body.blocks[0]):
                    result = func(*args, **kwargs)
                    gpu.terminator()
                    return result

            return wrapper

        return decorator

    @staticmethod
    def mlir_func(funcBody):
        @functools.wraps(funcBody)
        def wrapper(*args, **kwargs):
            function_name = funcBody.__name__

            def saveIR(module):
                """Save generated IR"""
                if True:  # self.saveIR:
                    # print(mlir_nvgpu_module)
                    original_stdout = sys.stdout
                    with open("nvdsl.mlir", "w") as f:
                        sys.stdout = f
                        print(module)
                        sys.stdout = original_stdout

            def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue":
                """Generate MLIR's Arith dialects binary operations."""
                rhs = const(rhs)
                if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
                    op += "F"
                    if op.startswith("Cmp"):
                        predicateAttr = getattr(arith, "CmpFPredicate").__dict__[
                            predAtt
                        ]
                elif arith._is_integer_like_type(
                    lhs.type
                ) and arith._is_integer_like_type(rhs.type):
                    if op == "Div" or op == "Rem":
                        op += "U"
                    op += "I"
                    if op.startswith("Cmp"):
                        predicateAttr = getattr(arith, "CmpIPredicate").__dict__[
                            predAtt
                        ]
                else:
                    raise NotImplementedError(
                        f"Unsupported '{op}' operands: {lhs}, {rhs}"
                    )

                if op.startswith("Cmp"):
                    op = getattr(arith, f"{op}Op")
                    return op(predicateAttr, lhs, rhs).result
                else:
                    op = getattr(arith, f"{op}Op")
                    return op(lhs, rhs).result

            @ir.register_value_caster(ir.IndexType.static_typeid)
            @ir.register_value_caster(ir.F32Type.static_typeid)
            @ir.register_value_caster(ir.F16Type.static_typeid)
            @ir.register_value_caster(ir.F64Type.static_typeid)
            @ir.register_value_caster(ir.IntegerType.static_typeid)
            class ArithValue(ir.Value):
                """Overloads operators for MLIR's Arith dialects binary operations."""

                def __init__(self, v):
                    super().__init__(v)

                __add__ = partialmethod(_binary_op, op="Add")
                __sub__ = partialmethod(_binary_op, op="Sub")
                __mul__ = partialmethod(_binary_op, op="Mul")
                __truediv__ = partialmethod(_binary_op, op="Div")
                __mod__ = partialmethod(_binary_op, op="Rem")
                __xor__ = partialmethod(_binary_op, op="XOr")
                __lt__ = partialmethod(_binary_op, op="Cmp", predAtt="ult")
                __le__ = partialmethod(_binary_op, op="Cmp", predAtt="ule")
                __eq__ = partialmethod(_binary_op, op="Cmp", predAtt="eq")
                __ne__ = partialmethod(_binary_op, op="Cmp", predAtt="ne")
                __gt__ = partialmethod(_binary_op, op="Cmp", predAtt="ugt")
                __ge__ = partialmethod(_binary_op, op="Cmp", predAtt="uge")
                __and__ = partialmethod(_binary_op, op="And")
                __or__ = partialmethod(_binary_op, op="Or")

                def __str__(self):
                    return (
                        super()
                        .__str__()
                        .replace(ir.Value.__name__, ArithValue.__name__)
                    )

            # Generate MLIR Context and start generating IR
            with ir.Context(), ir.Location.unknown():
                types = []
                for arg in args:
                    types.append(get_mlir_ty(arg))

                # Build IR
                module = ir.Module.create()
                with ir.InsertionPoint(module.body):
                    fop = func.FuncOp(function_name, (types, []))
                    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
                    with ir.InsertionPoint(fop.add_entry_block()):
                        fargs = []
                        for i, a in enumerate(types):
                            fargs.append(fop.arguments[i])

                        # Call user function body
                        result = funcBody(*fargs, **kwargs)
                        func.ReturnOp([])

                # Save IR in a file
                # saveIR(module)

                # Verify the module
                module.operation.verify()

                # Compile and JIT MLIR module
                options = "cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
                support_lib = os.getenv("SUPPORT_LIB")
                if not os.path.exists(support_lib):
                    raise FileNotFoundError(
                        errno.ENOENT, os.strerror(errno.ENOENT), support_lib
                    )
                compiler = nvgpucompiler.NvgpuCompiler(
                    options, opt_level=3, shared_libs=[support_lib]
                )
                engine = compiler.compile_and_jit(module)

                # Convert input arguments to MLIR arguments
                newArgs = get_mlir_func_obj_ty(args)

                # Run the compiled program
                engine.invoke(function_name, *newArgs)

            return result

        return wrapper
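
# Illustrative end-to-end sketch (not executed; the kernel body, shapes, and
# the function name are assumptions): `NVDSL.mlir_func` turns a Python function
# into an MLIR func.func, JIT-compiles the module, and invokes it with the
# NumPy arguments, while `NVDSL.mlir_gpu_launch` wraps the nested function's
# body in a gpu.launch region.
#
#   @NVDSL.mlir_func
#   def main(x):
#       @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=0)
#       def kernel():
#           tidx = gpu.thread_id(gpu.Dimension.x)
#           ...  # build device-side IR with the helpers above
#
#       kernel()
#
#   main(np.ones((128, 128), np.float32))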