# xref: /llvm-project/mlir/test/python/dialects/transform_sparse_tensor_ext.py (revision 3d27d1152eacd6432485cd81d471bb03987a83e1)
# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import transform
5from mlir.dialects.transform import sparse_tensor
6
7
def run(f):
    """Wrap ``f`` in a transform sequence, print the module, and return ``f``.

    Intended as a decorator. Under a fresh ``Context`` with an unknown
    location, a new module is created holding a ``transform.SequenceOp``
    built with ``FailurePropagationMode.Propagate``, an empty type list and
    ``transform.AnyOpType`` as its root type. ``f`` is invoked with the
    sequence's ``bodyTarget`` while the insertion point is the sequence body,
    and a ``transform.YieldOp`` is appended after it. The module is then
    printed after a ``TEST: <function name>`` header so FileCheck can anchor
    on it. ``f`` itself is returned unchanged.
    """
    with Context(), Location.unknown():
        mod = Module.create()
        with InsertionPoint(mod.body):
            root_type = transform.AnyOpType.get()
            seq = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate, [], root_type
            )
            # Ops emitted by `f` land inside the sequence body; the yield
            # terminator must come last, after everything `f` created.
            with InsertionPoint(seq.body):
                body_target = seq.bodyTarget
                f(body_target)
                transform.YieldOp()
        header = "\nTEST:"
        print(header, f.__name__)
        print(mod)
    return f
23
24
@run
def testMatchSparseInOut(target):
    """Emit the sparse in/out match op on the sequence's block argument.

    ``target`` is the value ``run`` passes in (the sequence body target).
    NOTE(review): the first builder argument is presumably the op's result
    type (``!transform.any_op``), following the usual ODS builder order.
    """
    result_type = transform.AnyOpType.get()
    sparse_tensor.MatchSparseInOut(result_type, target)
    # CHECK-LABEL: TEST: testMatchSparseInOut
    # CHECK:       transform.sequence
    # CHECK-NEXT:  ^{{.*}}(%[[ARG0:.*]]: !transform.any_op):
    # CHECK-NEXT:    transform.sparse_tensor.match.sparse_inout %[[ARG0]]
32