# RUN: %PYTHON %s | FileCheck %s

from typing import Callable
from mlir import ir
from mlir.dialects import scf, pdl
from mlir.dialects.transform import (
    structured,
    get_parent_op,
    apply_patterns_canonicalization,
    apply_cse,
    any_op_t,
)
from mlir.dialects.transform import FailurePropagationMode
from mlir.dialects.transform.structured import structured_match
from mlir.dialects.transform.loop import loop_unroll
from mlir.dialects.transform.extras import (
    constant_param,
    OpHandle,
    insert_transform_script,
    sequence,
    apply_patterns,
)
from mlir.extras import types as T


def construct_and_print_in_module(f):
    """Run ``f`` with the insertion point at the body of a fresh module, then print it.

    Prints a ``TEST: <name>`` label first so the FileCheck label directives
    can anchor on it. Returns ``f`` unchanged so it can be used as a decorator.
    """
    print("\nTEST:", f.__name__)
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        with ir.InsertionPoint(module.body):
            f()
        print(module)
    return f


def _build_and_verify(script: Callable[[OpHandle], None], insertion_target):
    """Shared driver for the transform-script builder decorators.

    Prints the ``TEST: <name>`` label, creates a module tagged with
    ``transform.with_named_sequence``, inserts ``script`` (dumping the
    generated IR) at the location produced by ``insertion_target(module)``,
    and verifies the module.

    ``insertion_target`` is called inside the active context so any IR object
    it builds (e.g. an ``ir.InsertionPoint``) belongs to that context.
    """
    print("\nTEST:", script.__name__)
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        module.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
        insert_transform_script(
            insertion_target(module),
            script=script,
            dump_script=True,
        )
        module.operation.verify()


def build_transform_script(script: Callable[[OpHandle], None]):
    """Decorator: build ``script`` into the module body, dump it, and verify it.

    Returns ``script`` so the decorated name stays bound to the function
    (consistent with ``construct_and_print_in_module``).
    """
    _build_and_verify(script, lambda module: module.body)
    return script


def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None]):
    """Decorator: like ``build_transform_script`` but passes an explicit
    ``ir.InsertionPoint`` at the beginning of the module body.

    Returns ``script`` so the decorated name stays bound to the function.
    """
    _build_and_verify(
        script, lambda module: ir.InsertionPoint.at_block_begin(module.body)
    )
    return script


# CHECK-LABEL: TEST: test_build_script_at_insertion_point
@build_transform_script_at_insertion_point
def test_build_script_at_insertion_point(op: OpHandle):
    pass
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.yield
    # CHECK-NEXT: }


# CHECK-LABEL: TEST: test_constant_param_int
@build_transform_script
def test_constant_param_int(_: OpHandle):
    constant_param(ir.IntegerAttr.get(T.i32(), 42))
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i32
    # CHECK-SAME: !transform.param<i32>


# CHECK-LABEL: TEST: test_constant_param_py_int
@build_transform_script
def test_constant_param_py_int(_: OpHandle):
    constant_param(42)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i64
    # CHECK-SAME: !transform.param<i64>


# CHECK-LABEL: TEST: test_constant_param_symbol_attr
@build_transform_script
def test_constant_param_symbol_attr(_: OpHandle):
    constant_param(ir.SymbolRefAttr.get(["symbol"]))
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant @symbol
    # CHECK-SAME: !transform.any_param


# CHECK-LABEL: TEST: test_constant_param_type
@build_transform_script
def test_constant_param_type(_: OpHandle):
    constant_param(ir.TypeAttr.get(T.i32()))
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant i32
    # CHECK-SAME: !transform.any_param


# CHECK-LABEL: TEST: test_get_defining_op
@build_transform_script
def test_get_defining_op(op: OpHandle):
    op.get_result().get_defining_op()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]
    # CHECK-SAME: !transform.any_value
    # CHECK-NEXT: %[[VAL_2:.*]] = transform.get_defining_op %[[VAL_1]]


# CHECK-LABEL: TEST: test_get_result
@build_transform_script
def test_get_result(op: OpHandle):
    op.get_result()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]


# CHECK-LABEL: TEST: test_match_ops_single
@build_transform_script
def test_match_ops_single(op: OpHandle):
    op.match_ops(scf.ForOp)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
    # CHECK-SAME: in %[[VAL_0]]
    # CHECK-SAME: -> !transform.op<"scf.for">


# CHECK-LABEL: TEST: test_match_ops_string_name
@build_transform_script
def test_match_ops_string_name(op: OpHandle):
    op.match_ops("linalg.matmul")
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: ops{["linalg.matmul"]} in %[[VAL_0]]


# CHECK-LABEL: TEST: test_match_ops_string_iface
@build_transform_script
def test_match_ops_string_iface(op: OpHandle):
    op.match_ops("LinalgOp")
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]


# CHECK-LABEL: TEST: test_match_ops_iface
@build_transform_script
def test_match_ops_iface(op: OpHandle):
    op.match_ops(structured.MatchInterfaceEnum.LinalgOp)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]


# CHECK-LABEL: TEST: test_match_ops_multiple
@build_transform_script
def test_match_ops_multiple(op: OpHandle):
    op.match_ops([scf.ForOp, scf.ForallOp])
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
    # CHECK-SAME: -> !transform.any_op


# CHECK-LABEL: TEST: test_match_ops_mixed
@build_transform_script
def test_match_ops_mixed(op: OpHandle):
    op.match_ops([scf.ForOp, "linalg.matmul", scf.ForallOp])
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
    # CHECK-SAME: -> !transform.any_op


# CHECK-LABEL: TEST: test_print_message
@build_transform_script
def test_print_message(op: OpHandle):
    op.print("message")
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.print %[[VAL_0]] {name = "message"}


# CHECK-LABEL: TEST: test_print_plain
@build_transform_script
def test_print_plain(op: OpHandle):
    op.print()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.print %[[VAL_0]]


# CHECK-LABEL: TEST: test_sequence_region
@construct_and_print_in_module
def test_sequence_region():
    # CHECK: transform.sequence failures(propagate) {
    # CHECK: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
    # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK: %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
    # CHECK: transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
    # CHECK: }
    @sequence([], FailurePropagationMode.Propagate, [])
    def basic(target: any_op_t()):
        m = structured_match(any_op_t(), target, ops=["arith.addi"])
        loop = get_parent_op(pdl.op_t(), m, op_name="scf.for")
        loop_unroll(loop, 4)


# CHECK-LABEL: TEST: test_apply_patterns
@construct_and_print_in_module
def test_apply_patterns():
    # CHECK: transform.sequence failures(propagate) {
    # CHECK: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
    # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK: %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
    # CHECK: apply_patterns to %[[VAL_2]] {
    # CHECK: transform.apply_patterns.canonicalization
    # CHECK: } : !pdl.operation
    # CHECK: %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK: apply_cse to %[[VAL_3]] : !transform.any_op
    # CHECK: }
    @sequence([], FailurePropagationMode.Propagate, [])
    def basic(variant_op: any_op_t()):
        matmul = structured_match(any_op_t(), variant_op, ops=["linalg.matmul"])
        top_func = get_parent_op(pdl.op_t(), matmul, op_name="func.func")

        @apply_patterns(top_func)
        def pats():
            apply_patterns_canonicalization()

        # Re-match at sequence level (outside the patterns region) for CSE.
        top_func = structured_match(any_op_t(), variant_op, ops=["func.func"])
        apply_cse(top_func)