# RUN: %PYTHON %s | FileCheck %s from mlir.ir import * from mlir.dialects import transform from mlir.dialects.transform import sparse_tensor def run(f): with Context(), Location.unknown(): module = Module.create() with InsertionPoint(module.body): sequence = transform.SequenceOp( transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get(), ) with InsertionPoint(sequence.body): f(sequence.bodyTarget) transform.YieldOp() print("\nTEST:", f.__name__) print(module) return f @run def testMatchSparseInOut(target): sparse_tensor.MatchSparseInOut(transform.AnyOpType.get(), target) # CHECK-LABEL: TEST: testMatchSparseInOut # CHECK: transform.sequence # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): # CHECK-NEXT: transform.sparse_tensor.match.sparse_inout %[[ARG0]]