from enum import Enum
import functools, sys, ctypes, os, errno
import numpy as np
from functools import partialmethod
from mlir import ir
from mlir.dialects import arith, func, gpu, memref, nvgpu, scf, nvvm
from mlir.extras import types as T
from mlir import runtime as rt
from tools import nvgpucompiler

# Sentinel for a dynamic dimension in MLIR shaped types (ShapedType::kDynamic).
MLIR_DYNAMIC = -9223372036854775808


def const(value: int, ty=None):
    """Materialize `value` as an arith.constant unless it is already an SSA value."""
    ty = T.index() if ty is None else ty
    if isinstance(value, ir.Value) and (
        ir.IndexType.isinstance(value.type) or ir.IntegerType.isinstance(value.type)
    ):
        return value
    return arith.constant(ty, value)


def get_type_size(ty):
    """Return the size of an MLIR scalar or memref type in bytes."""
    if ir.MemRefType.isinstance(ty):
        size = get_type_size(ty.element_type)
        for sz in ty.shape:
            size *= sz
        return size
    if ir.FloatType.isinstance(ty):
        return ir.FloatType(ty).width // 8
    if ir.IntegerType.isinstance(ty):
        return ir.IntegerType(ty).width // 8
    raise NotImplementedError(ty)
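
# For example, a memref<128x64xf16> has size 128 * 64 * 2 = 16384 bytes.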


def get_mlir_func_obj_ty(inputArgs):
    """Pack Python arguments into ctypes objects for ExecutionEngine.invoke."""
    args = []
    c_int_p = ctypes.c_int * 1
    c_float_p = ctypes.c_float * 1
    c_bool_p = ctypes.c_bool * 1
    for arg in inputArgs:
        # bool must be tested before int: bool is a subclass of int in Python.
        if isinstance(arg, bool):
            args.append(c_bool_p(arg))
        elif isinstance(arg, int):
            args.append(c_int_p(arg))
        elif isinstance(arg, float):
            args.append(c_float_p(arg))
        elif isinstance(arg, np.ndarray):
            args.append(
                ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(arg)))
            )
        else:
            raise NotImplementedError(arg)
    return args
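
# NumPy arrays are passed as pointer-to-pointer-to-memref-descriptor, matching
# the calling convention that `llvm.emit_c_interface` functions expect.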


class Mbarriers:
    """A group of mbarrier objects living in workgroup (shared) memory."""

    def __init__(self, number_of_barriers=1):
        self.mbar_ty = ir.Type.parse(
            "!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>, num_barriers = "
            + str(number_of_barriers)
            + ">"
        )
        self.mbar_group_op = nvgpu.mbarrier_create(self.mbar_ty)
        self.number_of_barriers = number_of_barriers

    def __getitem__(self, key):
        # Select barrier `key` within the group; returns self so that
        # mbar[i].init(...) / mbar[i].arrive(...) chain naturally.
        self.id_op = const(key)
        return self

    def init(self, count: int, predicate=None):
        count_op = const(count)
        if predicate is None:
            nvgpu.mbarrier_init(self.mbar_group_op, count_op, self.id_op)
        else:
            nvgpu.mbarrier_init(
                self.mbar_group_op, count_op, self.id_op, predicate=predicate
            )

    def arrive(self, txcount: int = 0, predicate=None):
        if txcount != 0:
            txcount_op = const(txcount)
            nvgpu.mbarrier_arrive_expect_tx(
                self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
            )
        else:
            nvgpu.mbarrier_arrive(
                ir.Type.parse("!nvgpu.mbarrier.token"), self.mbar_group_op, self.id_op
            )

    def try_wait(self, phase: bool = False, ticks: int = 10000000):
        ticks_op = const(ticks)
        phase_op = const(phase, T.bool())
        nvgpu.MBarrierTryWaitParityOp(
            self.mbar_group_op,
            phase_op,
            ticks_op,
            mbarId=self.id_op,
        )
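
# Illustrative usage (a sketch, not executed here; `mbar` is a hypothetical
# name inside a kernel body):
#
#   mbar = Mbarriers(number_of_barriers=2)
#   mbar[0].init(1)               # one arrival completes barrier 0
#   mbar[0].arrive(txcount=4096)  # expect 4096 bytes of async transactions
#   mbar[0].try_wait()            # spin on parity phase 0 until complete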


class TMA:
    """A class that builds a TMA descriptor."""

    def __init__(
        self,
        tma_box_shape,
        memref_ty,
        swizzle=nvgpu.TensorMapSwizzleKind.SWIZZLE_NONE,
        l2promo=nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
        oob=nvgpu.TensorMapOOBKind.OOB_ZERO,
        interleave=nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
    ):
        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
        self.tma_box_shape = tma_box_shape
        self.memref_ty = memref_ty  # MemRefType
        self.tma_memref = ir.MemRefType.get(tma_box_shape, memref_ty.element_type)

    @property
    def tensormap_descriptor_ty(self):
        """Returns a tensormap descriptor type."""
        tensorMemrefType = ir.MemRefType.get(
            self.tma_box_shape,
            self.memref_ty.element_type,
            memory_space=ir.Attribute.parse("3"),
        )
        return nvgpu.TensorMapDescriptorType.get(
            tensorMemrefType,
            self.swizzle,
            self.l2promo,
            self.oob,
            self.interleave,
        )

    def create_descriptor(self, device_ptr):
        tma_descriptor_ty = self.tensormap_descriptor_ty
        device_unranked_memref = memref.CastOp(
            ir.UnrankedMemRefType.get(
                self.memref_ty.element_type, self.memref_ty.memory_space
            ),
            device_ptr,
        )
        self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
            tma_descriptor_ty, device_unranked_memref, map(const, self.tma_box_shape)
        )
        return self.tma_descriptor.result

    def prefetch(self, predicate=None):
        nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)

    def load(self, dest, mbarrier: Mbarriers, coords=(0,), predicate=None):
        nvgpu.TmaAsyncLoadOp(
            dest,
            mbarrier.mbar_group_op,
            self.tma_descriptor,
            coordinates=map(const, coords),
            mbarId=mbarrier.id_op,
            predicate=predicate,
        )
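
# Illustrative usage (a sketch; `memref_ty`, `device_ptr`, `smem_buffer`, and
# `mbar` are hypothetical values created elsewhere):
#
#   tma = TMA([128, 64], memref_ty, swizzle=nvgpu.TensorMapSwizzleKind.SWIZZLE_128B)
#   tma.create_descriptor(device_ptr)  # host side, before the kernel launch
#   tma.prefetch()
#   tma.load(smem_buffer, mbar[0], coords=[0, 0])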


WARP_GROUP_SIZE = 128  # Number of threads in a warpgroup


class Warpgroup:
    """Runs the `with` body on a single warpgroup and sets its register budget."""

    def __init__(self, primary_thread, register_size):
        assert (primary_thread % WARP_GROUP_SIZE) == 0
        tidx = gpu.thread_id(gpu.Dimension.x)
        self.primary_thread = primary_thread
        self.register_size = register_size
        self.is_wg_primary = (tidx % WARP_GROUP_SIZE) == 0
        self.wg_id = tidx / WARP_GROUP_SIZE
        self.is_me = self.wg_id == (primary_thread // WARP_GROUP_SIZE)

    def __enter__(self):
        if_op = scf.IfOp(self.is_me)
        self.ipoint_op = ir.InsertionPoint(if_op.then_block)
        self.ipoint_op.__enter__()
        if self.register_size < 64:
            nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.decrease)
        else:
            nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.increase)

    def __exit__(self, exc_type, exc_value, traceback):
        scf.yield_([])
        self.ipoint_op.__exit__(exc_type, exc_value, traceback)
        # Returning False propagates exceptions raised while building the
        # region instead of silently swallowing them.
        return False
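
# Illustrative usage: the body runs only on the warpgroup whose primary thread
# is 128 (threads 128-255), with an increased register budget:
#
#   with Warpgroup(primary_thread=128, register_size=232):
#       ...  # consumer code, e.g. WGMMA and the epilogue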


class WGMMAType(Enum):
    Accumulator = 1
    Descriptor = 2


class WGMMAMatrix:
    """An operand of a warpgroup-level MMA: an accumulator or a descriptor over shared memory."""

    def __init__(
        self,
        matrix_type: WGMMAType,
        shape: list = None,
        desc: TMA = None,
        smem=None,
        ty=None,
        acc_op=None,
    ):
        if acc_op is None:
            self.M = shape[0]
            self.N = shape[1]
            self.ty = ty
            self.matrix_type = matrix_type
            self.desc = desc
            self.smem = smem
            if matrix_type is WGMMAType.Accumulator:
                self.acc_op = nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
        else:
            self.acc_op = acc_op
            self.matrix_type = WGMMAType.Accumulator

    @property
    def acc_ty(self):
        parse_str = f"!nvgpu.warpgroup.accumulator<fragmented=vector<{self.M}x{self.N}x{self.ty}>>"
        return ir.Type.parse(parse_str)

    @property
    def wgmma_ty(self):
        parse_str = f"!nvgpu.warpgroup.descriptor<tensor=memref<{self.M}x{self.N}x{self.desc.memref_ty.element_type}, #gpu.address_space<workgroup>>>"
        return ir.Type.parse(parse_str)

    def store_accumulator(self, dest):
        assert self.matrix_type == WGMMAType.Accumulator
        nvgpu.warpgroup_mma_store(self.acc_op, dest)

    def update_smem(self, smem):
        self.smem = smem

    def update_accumulator(self, acc_op):
        self.acc_op = acc_op

    def __matmul__(self, rhs):
        lhs = nvgpu.warpgroup_generate_descriptor(
            self.wgmma_ty, self.smem, self.desc.tma_descriptor
        )
        rhs = nvgpu.warpgroup_generate_descriptor(
            rhs.wgmma_ty, rhs.smem, rhs.desc.tma_descriptor
        )
        return [lhs, rhs]

    def __iadd__(self, matmulResult):
        lhs = matmulResult[0]
        rhs = matmulResult[1]
        acc_op = nvgpu.WarpgroupMmaOp(
            self.acc_op.type, lhs, rhs, self.acc_op, transposeB=True
        )
        return WGMMAMatrix(WGMMAType.Accumulator, acc_op=acc_op)
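
# Illustrative usage (a sketch; `a_tma`, `b_tma`, `a_smem`, `b_smem`, and the
# tile sizes are hypothetical values created elsewhere):
#
#   A = WGMMAMatrix(WGMMAType.Descriptor, [128, 64], desc=a_tma, smem=a_smem)
#   B = WGMMAMatrix(WGMMAType.Descriptor, [64, 128], desc=b_tma, smem=b_smem)
#   D = WGMMAMatrix(WGMMAType.Accumulator, shape=[128, 128], ty=T.f32())
#   D += A @ B  # emits nvgpu.warpgroup.mma, accumulating into D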


def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
    """Return dynamic shared memory, or a typed view of it at a byte offset."""
    smem_space_str = "#gpu.address_space<workgroup>"
    smem_space = ir.Attribute.parse(smem_space_str)
    dynamic_smem = gpu.dynamic_shared_memory(
        ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
    )
    if shape is None:
        return dynamic_smem
    memref_ty = ir.MemRefType.get(shape, ty, memory_space=smem_space)
    return memref.view(memref_ty, dynamic_smem, const(offset), [])
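
# Illustrative usage: carve two non-overlapping f16 tiles out of dynamic
# shared memory (the offset is in bytes; names are hypothetical):
#
#   a_smem = get_dynamic_shared_memory((128, 64), T.f16())
#   b_smem = get_dynamic_shared_memory((64, 128), T.f16(), offset=128 * 64 * 2)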


def get_mlir_ty(arg):
    """Map a host-side Python/NumPy argument to the matching MLIR type."""

    def get_mlir_ty_from_np(dtype):
        if dtype == np.float16:
            return T.f16()
        if dtype == np.float32:
            return T.f32()
        if dtype == np.float64:
            return T.f64()
        if dtype == np.int32:
            return T.i32()
        if dtype == np.int64:
            return T.i64()
        raise NotImplementedError(dtype)

    if isinstance(arg, bool):
        return T.bool()
    elif isinstance(arg, int):
        return T.index()
    elif isinstance(arg, float):
        return T.f32()
    elif isinstance(arg, np.ndarray):
        descriptor = rt.get_ranked_memref_descriptor(arg)
        dtype = get_mlir_ty_from_np(arg.dtype)
        shape = descriptor.shape
        return ir.MemRefType.get(shape, dtype)
    raise NotImplementedError(arg)
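
# For example, np.zeros((64, 32), np.float16) maps to memref<64x32xf16>, a
# Python int maps to index, and a Python float maps to f32.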


class NVDSL:
    """Decorators for building, compiling and running MLIR from Python."""

    @staticmethod
    def mlir_gpu_launch(grid=(1, 1, 1), block=(1, 1, 1), smem=0):
        """Decorator that wraps a function body in a gpu.launch region."""

        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                launch_op = gpu.LaunchOp(
                    None,
                    [],
                    *map(const, grid),
                    *map(const, block),
                    dynamicSharedMemorySize=arith.constant(T.i32(), smem),
                )
                # The launch body block carries 12 index arguments: block ids,
                # thread ids, grid dims and block dims (x, y, z each).
                launch_op.body.blocks.append(*([T.index()] * 12))
                with ir.InsertionPoint(launch_op.body.blocks[0]):
                    result = func(*args, **kwargs)
                    gpu.terminator()
                    return result

            return wrapper

        return decorator

    @staticmethod
    def mlir_func(funcBody):
        """Decorator that builds `funcBody` as MLIR, JIT-compiles it, and runs it."""

        @functools.wraps(funcBody)
        def wrapper(*args, **kwargs):
            function_name = funcBody.__name__

            def saveIR(module):
                """Save the generated IR to a file for inspection."""
                with open("nvdsl.mlir", "w") as f:
                    f.write(str(module))

            def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue":
                """Build a binary operation from MLIR's arith dialect."""
                rhs = const(rhs)
                if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
                    op += "F"
                    if op.startswith("Cmp"):
                        predicateAttr = getattr(arith.CmpFPredicate, predAtt)
                elif arith._is_integer_like_type(
                    lhs.type
                ) and arith._is_integer_like_type(rhs.type):
                    if op == "Div" or op == "Rem":
                        op += "U"  # this DSL only generates unsigned div/rem
                    op += "I"
                    if op.startswith("Cmp"):
                        predicateAttr = getattr(arith.CmpIPredicate, predAtt)
                else:
                    raise NotImplementedError(
                        f"Unsupported '{op}' operands: {lhs}, {rhs}"
                    )

                op_cls = getattr(arith, f"{op}Op")
                if op.startswith("Cmp"):
                    return op_cls(predicateAttr, lhs, rhs).result
                return op_cls(lhs, rhs).result

            @ir.register_value_caster(ir.IndexType.static_typeid)
            @ir.register_value_caster(ir.F32Type.static_typeid)
            @ir.register_value_caster(ir.F16Type.static_typeid)
            @ir.register_value_caster(ir.F64Type.static_typeid)
            @ir.register_value_caster(ir.IntegerType.static_typeid)
            class ArithValue(ir.Value):
                """Overloads Python operators to build arith dialect binary ops."""

                def __init__(self, v):
                    super().__init__(v)

                __add__ = partialmethod(_binary_op, op="Add")
                __sub__ = partialmethod(_binary_op, op="Sub")
                __mul__ = partialmethod(_binary_op, op="Mul")
                __truediv__ = partialmethod(_binary_op, op="Div")
                __mod__ = partialmethod(_binary_op, op="Rem")
                __xor__ = partialmethod(_binary_op, op="XOr")
                __lt__ = partialmethod(_binary_op, op="Cmp", predAtt="ult")
                __le__ = partialmethod(_binary_op, op="Cmp", predAtt="ule")
                __eq__ = partialmethod(_binary_op, op="Cmp", predAtt="eq")
                __ne__ = partialmethod(_binary_op, op="Cmp", predAtt="ne")
                __gt__ = partialmethod(_binary_op, op="Cmp", predAtt="ugt")
                __ge__ = partialmethod(_binary_op, op="Cmp", predAtt="uge")
                __and__ = partialmethod(_binary_op, op="And")
                __or__ = partialmethod(_binary_op, op="Or")

                def __str__(self):
                    return (
                        super()
                        .__str__()
                        .replace(ir.Value.__name__, ArithValue.__name__)
                    )
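
            # With these casters registered, index/integer/float SSA values
            # support Python operators directly: e.g. `tidx % WARP_GROUP_SIZE`
            # in Warpgroup above emits arith.remui, and `x + y` on two f32
            # values emits arith.addf.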

            # Generate MLIR Context and start generating IR
            with ir.Context(), ir.Location.unknown():
                types = [get_mlir_ty(arg) for arg in args]

                # Build IR
                module = ir.Module.create()
                with ir.InsertionPoint(module.body):
                    fop = func.FuncOp(function_name, (types, []))
                    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
                    with ir.InsertionPoint(fop.add_entry_block()):
                        fargs = list(fop.arguments)

                        # Call user function body
                        result = funcBody(*fargs, **kwargs)
                        func.ReturnOp([])

                # Save IR in a file
                # saveIR(module)

                # Verify the module
                module.operation.verify()

                # Compile and JIT MLIR module
                options = "cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
                support_lib = os.getenv("SUPPORT_LIB")
                if support_lib is None or not os.path.exists(support_lib):
                    raise FileNotFoundError(
                        errno.ENOENT, os.strerror(errno.ENOENT), support_lib
                    )
                compiler = nvgpucompiler.NvgpuCompiler(
                    options, opt_level=3, shared_libs=[support_lib]
                )
                engine = compiler.compile_and_jit(module)

            # Pack the Python arguments for the ExecutionEngine
            newArgs = get_mlir_func_obj_ty(args)

            # Run the compiled program
            engine.invoke(function_name, *newArgs)

            return result

        return wrapper
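
# End-to-end sketch of how this DSL is used (a hypothetical kernel, mirroring
# the NVGPU examples that import this module). `NVDSL.mlir_func` builds,
# verifies and JIT-compiles the MLIR module, then invokes it with the NumPy
# arguments:
#
#   @NVDSL.mlir_func
#   def saxpy(x, y):
#       @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1))
#       def kernel():
#           tidx = gpu.thread_id(gpu.Dimension.x)
#           ...
#       kernel()
#
#   x = np.ones((128,), np.float32)
#   y = np.zeros((128,), np.float32)
#   saxpy(x, y)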