# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import sparse_tensor


def run(f):
    """Build a transform sequence around ``f``, print it, and return ``f``.

    Meant to be used as a decorator. Inside a fresh ``Context`` with an
    unknown ``Location``, this creates a module containing one
    ``transform.SequenceOp`` (failure mode ``Propagate``). It calls ``f`` with
    the sequence body's block argument so ``f`` can emit ops into the body, and
    closes the body with ``transform.YieldOp``. It then prints a header line
    naming ``f`` followed by the module, for FileCheck to match.
    """
    with Context():
        with Location.unknown():
            any_op = transform.AnyOpType.get()
            propagate = transform.FailurePropagationMode.Propagate
            module = Module.create()
            with InsertionPoint(module.body):
                # NOTE(review): the empty list is presumably the sequence's
                # result types; confirm against the SequenceOp binding.
                seq = transform.SequenceOp(propagate, [], any_op)
                with InsertionPoint(seq.body):
                    # Let the test case populate the body, then terminate it.
                    f(seq.bodyTarget)
                    transform.YieldOp()
            print(f"\nTEST: {f.__name__}")
            print(module)
    return f


@run
def testMatchSparseInOut(target):
    """Emit a sparse in/out match op on the sequence's block argument."""
    # Presumably the type of the handle the match op produces; verify.
    handle_type = transform.AnyOpType.get()
    sparse_tensor.MatchSparseInOut(handle_type, target)
    # CHECK-LABEL: TEST: testMatchSparseInOut
    # CHECK: transform.sequence
    # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op):
    # CHECK-NEXT: transform.sparse_tensor.match.sparse_inout %[[ARG0]]