xref: /llvm-project/mlir/test/python/dialects/transform_nvgpu_ext.py (revision bc30b415cadc477f229b8f3143979c41fe556a44)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import transform
5from mlir.dialects.transform import nvgpu
6
7
def run(f):
    """Execute ``f`` inside a freshly built module, print the module, return ``f``.

    Intended for use as a decorator, so ``f`` runs as soon as it is defined.
    A new ``Context`` and ``Location.unknown()`` are entered, an empty
    ``Module`` is created, and an ``InsertionPoint`` on that module's body is
    active while ``f`` runs. A banner containing ``f.__name__`` is printed
    just before ``f`` is called, and the whole module is printed afterwards.
    Returning ``f`` unchanged keeps the decorated name bound to the function.
    """
    with Context():
        with Location.unknown():
            test_module = Module.create()
            with InsertionPoint(test_module.body):
                # The banner text is matched against the printed output, so it
                # must stay exactly as written.
                print("\nTEST:", f.__name__)
                f()
            print(test_module)
    return f
16
17
@run
def testCreateAsyncGroups():
    """Build a transform sequence whose body contains nvgpu.CreateAsyncGroupsOp.

    The sequence uses ``FailurePropagationMode.Propagate`` and an empty list
    as its second argument (presumably result types -- confirm against the
    SequenceOp binding). Both the sequence's third argument and the
    CreateAsyncGroupsOp's first argument use ``transform.AnyOpType``. The op is
    applied to ``sequence.bodyTarget`` and the body is closed with a YieldOp.
    """
    any_op_type = transform.AnyOpType.get()
    failure_mode = transform.FailurePropagationMode.Propagate
    sequence = transform.SequenceOp(failure_mode, [], any_op_type)
    with InsertionPoint(sequence.body):
        nvgpu.CreateAsyncGroupsOp(any_op_type, sequence.bodyTarget)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testCreateAsyncGroups
    # CHECK: transform.nvgpu.create_async_groups
28