xref: /llvm-project/mlir/test/python/dialects/transform_nvgpu_ext.py (revision bc30b415cadc477f229b8f3143979c41fe556a44)
1*bc30b415SOleksandr "Alex" Zinenko# RUN: %PYTHON %s | FileCheck %s
2*bc30b415SOleksandr "Alex" Zinenko
3*bc30b415SOleksandr "Alex" Zinenkofrom mlir.ir import *
4*bc30b415SOleksandr "Alex" Zinenkofrom mlir.dialects import transform
5*bc30b415SOleksandr "Alex" Zinenkofrom mlir.dialects.transform import nvgpu
6*bc30b415SOleksandr "Alex" Zinenko
7*bc30b415SOleksandr "Alex" Zinenko
def run(f):
    """Execute test body `f` right away and print the module it populates.

    Used as a decorator: inside a fresh MLIR ``Context`` with an unknown
    default ``Location``, a new empty ``Module`` is created. The ``TEST:``
    label line is printed and ``f`` is called with the module body as the
    active insertion point, so any ops ``f`` builds land in that module.
    The module is then printed so FileCheck can match the label and the
    generated IR.

    Args:
        f: Zero-argument callable that constructs ops.

    Returns:
        ``f`` itself, unchanged, so the decorated name stays bound to the
        original function.
    """
    with Context():
        with Location.unknown():
            mod = Module.create()
            body_ip = InsertionPoint(mod.body)
            with body_ip:
                test_name = f.__name__
                # Emitted before the IR so CHECK-LABEL can anchor each test.
                print("\nTEST:", test_name)
                f()
            print(mod)
    return f
16*bc30b415SOleksandr "Alex" Zinenko
17*bc30b415SOleksandr "Alex" Zinenko
@run
def testCreateAsyncGroups():
    """Emit a transform sequence whose body applies `create_async_groups`.

    Builds a ``transform.SequenceOp`` with ``Propagate`` failure mode. Inside
    its body it creates ``nvgpu.CreateAsyncGroupsOp`` on the sequence's
    ``bodyTarget`` and terminates the body with ``transform.YieldOp``.
    """
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.AnyOpType.get(),
    )
    with InsertionPoint(seq.body):
        result_type = transform.AnyOpType.get()
        target = seq.bodyTarget
        nvgpu.CreateAsyncGroupsOp(result_type, target)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testCreateAsyncGroups
    # CHECK: transform.nvgpu.create_async_groups
28