# RUN: %PYTHON %s | FileCheck %s

# Tests for the Python transform-dialect "extras" helpers. Each test prints a
# "TEST: <name>" header followed by the IR it builds; FileCheck matches that
# output against the directives embedded in the comments of this file, so the
# relative order of those directive comments is significant.

from typing import Callable
from mlir import ir
from mlir.dialects import scf, pdl
from mlir.dialects.transform import (
    structured,
    get_parent_op,
    apply_patterns_canonicalization,
    apply_cse,
    any_op_t,
)
from mlir.dialects.transform import FailurePropagationMode
from mlir.dialects.transform.structured import structured_match
from mlir.dialects.transform.loop import loop_unroll
from mlir.dialects.transform.extras import (
    constant_param,
    OpHandle,
    insert_transform_script,
    sequence,
    apply_patterns,
)
from mlir.extras import types as T


def construct_and_print_in_module(f):
    """Build IR by calling ``f`` inside a fresh module, then print the module.

    A ``TEST: <name>`` header is printed first so FileCheck can anchor on it.
    Returns ``f`` unchanged so this can be used as a decorator.
    """
    print("\nTEST:", f.__name__)
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        with ir.InsertionPoint(module.body):
            f()
        print(module)
    return f


def _build_and_verify_transform_script(
    script: Callable[[OpHandle], None],
    make_target: Callable[[ir.Block], object],
) -> Callable[[OpHandle], None]:
    """Shared driver for the ``build_transform_script*`` decorators.

    Prints a ``TEST: <name>`` header, creates a fresh module, inserts
    ``script`` via ``insert_transform_script`` (which dumps the generated
    script), and verifies the module.

    Args:
        script: Callback that receives the root ``OpHandle`` and emits
            transform ops through it.
        make_target: Maps the module body block to the first argument passed
            to ``insert_transform_script`` (the block itself, or an
            insertion point into it).

    Returns:
        ``script`` unchanged, so decorated names stay bound to the function.
    """
    print("\nTEST:", script.__name__)
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        # transform.named_sequence ops must live in a module tagged with this
        # unit attribute, otherwise module verification rejects them.
        module.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
        insert_transform_script(
            make_target(module.body),
            script=script,
            dump_script=True,
        )
        module.operation.verify()
    return script


def build_transform_script(script: Callable[[OpHandle], None]):
    """Decorator: insert ``script`` into a fresh module body, dump and verify it."""
    return _build_and_verify_transform_script(script, lambda block: block)


def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None]):
    """Decorator: like ``build_transform_script``, but inserts through an
    explicit ``ir.InsertionPoint`` at the beginning of the module body."""
    return _build_and_verify_transform_script(
        script, ir.InsertionPoint.at_block_begin
    )


# CHECK-LABEL: TEST: test_build_script_at_insertion_point
@build_transform_script_at_insertion_point
def test_build_script_at_insertion_point(op: OpHandle):
    pass
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.yield
    # CHECK-NEXT: }


# CHECK-LABEL: TEST: test_constant_param_int
@build_transform_script
def test_constant_param_int(_: OpHandle):
    constant_param(ir.IntegerAttr.get(T.i32(), 42))
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i32
    # CHECK-SAME: !transform.param<i32>


# CHECK-LABEL: TEST: test_constant_param_py_int
@build_transform_script
def test_constant_param_py_int(_: OpHandle):
    constant_param(42)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i64
    # CHECK-SAME: !transform.param<i64>


# CHECK-LABEL: TEST: test_constant_param_symbol_attr
@build_transform_script
def test_constant_param_symbol_attr(_: OpHandle):
    constant_param(ir.SymbolRefAttr.get(["symbol"]))
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant @symbol
    # CHECK-SAME: !transform.any_param


# CHECK-LABEL: TEST: test_constant_param_type
@build_transform_script
def test_constant_param_type(_: OpHandle):
    constant_param(ir.TypeAttr.get(T.i32()))
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant i32
    # CHECK-SAME: !transform.any_param


# CHECK-LABEL: TEST: test_get_defining_op
@build_transform_script
def test_get_defining_op(op: OpHandle):
    op.get_result().get_defining_op()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]
    # CHECK-SAME: !transform.any_value
    # CHECK-NEXT: %[[VAL_2:.*]] = transform.get_defining_op %[[VAL_1]]


# CHECK-LABEL: TEST: test_get_result
@build_transform_script
def test_get_result(op: OpHandle):
    op.get_result()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]


# CHECK-LABEL: TEST: test_match_ops_single
@build_transform_script
def test_match_ops_single(op: OpHandle):
    op.match_ops(scf.ForOp)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
    # CHECK-SAME: in %[[VAL_0]]
    # CHECK-SAME: -> !transform.op<"scf.for">


# CHECK-LABEL: TEST: test_match_ops_string_name
@build_transform_script
def test_match_ops_string_name(op: OpHandle):
    op.match_ops("linalg.matmul")
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: ops{["linalg.matmul"]} in %[[VAL_0]]


# CHECK-LABEL: TEST: test_match_ops_string_iface
@build_transform_script
def test_match_ops_string_iface(op: OpHandle):
    op.match_ops("LinalgOp")
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]


# CHECK-LABEL: TEST: test_match_ops_iface
@build_transform_script
def test_match_ops_iface(op: OpHandle):
    op.match_ops(structured.MatchInterfaceEnum.LinalgOp)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]


# CHECK-LABEL: TEST: test_match_ops_multiple
@build_transform_script
def test_match_ops_multiple(op: OpHandle):
    op.match_ops([scf.ForOp, scf.ForallOp])
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
    # CHECK-SAME: -> !transform.any_op


# CHECK-LABEL: TEST: test_match_ops_mixed
@build_transform_script
def test_match_ops_mixed(op: OpHandle):
    op.match_ops([scf.ForOp, "linalg.matmul", scf.ForallOp])
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
    # CHECK-SAME: -> !transform.any_op


# CHECK-LABEL: TEST: test_print_message
@build_transform_script
def test_print_message(op: OpHandle):
    op.print("message")
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.print %[[VAL_0]] {name = "message"}


# CHECK-LABEL: TEST: test_print_plain
@build_transform_script
def test_print_plain(op: OpHandle):
    op.print()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.print %[[VAL_0]]


# CHECK-LABEL: TEST: test_sequence_region
@construct_and_print_in_module
def test_sequence_region():
    # CHECK: transform.sequence failures(propagate) {
    # CHECK: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
    # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK: %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
    # CHECK: transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
    # CHECK: }
    @sequence([], FailurePropagationMode.Propagate, [])
    def basic(target: any_op_t()):
        m = structured_match(any_op_t(), target, ops=["arith.addi"])
        loop = get_parent_op(pdl.op_t(), m, op_name="scf.for")
        loop_unroll(loop, 4)


# CHECK-LABEL: TEST: test_apply_patterns
@construct_and_print_in_module
def test_apply_patterns():
    # CHECK: transform.sequence failures(propagate) {
    # CHECK: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
    # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK: %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
    # CHECK: apply_patterns to %[[VAL_2]] {
    # CHECK: transform.apply_patterns.canonicalization
    # CHECK: } : !pdl.operation
    # CHECK: %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK: apply_cse to %[[VAL_3]] : !transform.any_op
    # CHECK: }
    @sequence([], FailurePropagationMode.Propagate, [])
    def basic(variant_op: any_op_t()):
        matmul = structured_match(any_op_t(), variant_op, ops=["linalg.matmul"])
        top_func = get_parent_op(pdl.op_t(), matmul, op_name="func.func")

        # The decorator emits the apply_patterns region; the function object
        # itself is not used afterwards.
        @apply_patterns(top_func)
        def pats():
            apply_patterns_canonicalization()

        top_func = structured_match(any_op_t(), variant_op, ops=["func.func"])
        apply_cse(top_func)