# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import pdl as transform_pdl


def run(f):
    """Run test function ``f`` against a fresh, empty module and print it.

    A new ``Context`` and an unknown ``Location`` are activated, an empty
    ``Module`` is created, and ``f`` is called with the insertion point set
    to the module body. A ``TEST: <function name>`` header is printed before
    the test runs and the resulting module is printed afterwards, so
    FileCheck can locate each test's output by its label.

    Intended for use as a decorator: the test executes when the decorator is
    applied (i.e. at import time) and ``f`` itself is returned unchanged.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            print("\nTEST:", f.__name__)
            f(module)
        print(module)
    return f


@run
def testTypes(module: Module):
    """Print each transform dialect type and the accessors of parametric types."""
    # CHECK-LABEL: TEST: testTypes
    # CHECK: !transform.any_op
    any_op = transform.AnyOpType.get()
    print(any_op)

    # CHECK: !transform.any_param
    any_param = transform.AnyParamType.get()
    print(any_param)

    # CHECK: !transform.any_value
    any_value = transform.AnyValueType.get()
    print(any_value)

    # CHECK: !transform.op<"foo.bar">
    # CHECK: foo.bar
    concrete_op = transform.OperationType.get("foo.bar")
    print(concrete_op)
    print(concrete_op.operation_name)

    # CHECK: !transform.param<i32>
    # CHECK: i32
    param = transform.ParamType.get(IntegerType.get_signless(32))
    print(param)
    print(param.type)


@run
def testSequenceOp(module: Module):
    """Build a result-producing sequence whose body yields its block argument."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [transform.AnyOpType.get()],
        transform.AnyOpType.get(),
    )
    with InsertionPoint(sequence.body):
        transform.YieldOp([sequence.bodyTarget])
    # CHECK-LABEL: TEST: testSequenceOp
    # CHECK: = transform.sequence -> !transform.any_op failures(propagate) {
    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
    # CHECK:   yield %[[ARG0]] : !transform.any_op
    # CHECK: }


@run
def testNestedSequenceOp(module: Module):
    """Build three nested sequences; only the innermost one produces a result."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    with InsertionPoint(sequence.body):
        nested = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
        )
        with InsertionPoint(nested.body):
            doubly_nested = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [transform.AnyOpType.get()],
                nested.bodyTarget,
            )
            with InsertionPoint(doubly_nested.body):
                transform.YieldOp([doubly_nested.bodyTarget])
            transform.YieldOp()
        transform.YieldOp()
    # CHECK-LABEL: TEST: testNestedSequenceOp
    # CHECK: transform.sequence failures(propagate) {
    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
    # CHECK:   sequence %[[ARG0]] : !transform.any_op failures(propagate) {
    # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
    # CHECK:     = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
    # CHECK:     ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
    # CHECK:       yield %[[ARG2]] : !transform.any_op
    # CHECK:     }
    # CHECK:   }
    # CHECK: }


@run
def testSequenceOpWithExtras(module: Module):
    """Build a sequence whose entry block takes extra arguments after the target."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.AnyOpType.get(),
        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
    )
    with InsertionPoint(sequence.body):
        transform.YieldOp()
    # CHECK-LABEL: TEST: testSequenceOpWithExtras
    # CHECK: transform.sequence failures(propagate)
    # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):


@run
def testNestedSequenceOpWithExtras(module: Module):
    """Forward an outer sequence's extra block arguments into a nested sequence."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.AnyOpType.get(),
        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
    )
    with InsertionPoint(sequence.body):
        nested = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate,
            [],
            sequence.bodyTarget,
            sequence.bodyExtraArgs,
        )
        with InsertionPoint(nested.body):
            transform.YieldOp()
        transform.YieldOp()
    # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
    # CHECK: transform.sequence failures(propagate)
    # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
    # CHECK:   sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)


@run
def testTransformPDLOps(module: Module):
    """Run a PDL match inside with_pdl_patterns and yield the matched handle."""
    with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
    with InsertionPoint(with_pdl.body):
        sequence = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate,
            [transform.AnyOpType.get()],
            with_pdl.bodyTarget,
        )
        with InsertionPoint(sequence.body):
            match = transform_pdl.PDLMatchOp(
                transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
            )
            transform.YieldOp(match)
    # CHECK-LABEL: TEST: testTransformPDLOps
    # CHECK: transform.with_pdl_patterns {
    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
    # CHECK:   = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
    # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
    # CHECK:     %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
    # CHECK:     yield %[[RES]] : !transform.any_op
    # CHECK:   }
    # CHECK: }


@run
def testNamedSequenceOp(module: Module):
    """Build a named sequence whose argument carries transform.consumed."""
    # The module attribute is required for named sequences to be valid here.
    module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
    named_sequence = transform.NamedSequenceOp(
        "__transform_main",
        [transform.AnyOpType.get()],
        [transform.AnyOpType.get()],
        arg_attrs=[{"transform.consumed": UnitAttr.get()}],
    )
    with InsertionPoint(named_sequence.body):
        transform.YieldOp([named_sequence.bodyTarget])
    # CHECK-LABEL: TEST: testNamedSequenceOp
    # CHECK: module attributes {transform.with_named_sequence} {
    # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op {
    # CHECK:   yield %[[ARG0]] : !transform.any_op


@run
def testGetParentOp(module: Module):
    """Build get_parent_op with isolated_from_above and nth_parent set."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    with InsertionPoint(sequence.body):
        transform.GetParentOp(
            transform.AnyOpType.get(),
            sequence.bodyTarget,
            isolated_from_above=True,
            nth_parent=2,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testGetParentOp
    # CHECK: transform.sequence
    # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
    # CHECK:   = get_parent_op %[[ARG1]] {isolated_from_above, nth_parent = 2 : i64}


@run
def testMergeHandlesOp(module: Module):
    """Build merge_handles over the sequence's single block argument."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    with InsertionPoint(sequence.body):
        transform.MergeHandlesOp([sequence.bodyTarget])
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMergeHandlesOp
    # CHECK: transform.sequence
    # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
    # CHECK:   = merge_handles %[[ARG1]]


@run
def testApplyPatternsOpCompact(module: Module):
    """Build apply_patterns with canonicalization on an !transform.any_op handle."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    with InsertionPoint(sequence.body):
        with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns):
            transform.ApplyCanonicalizationPatternsOp()
        transform.YieldOp()
    # CHECK-LABEL: TEST: testApplyPatternsOpCompact
    # CHECK: apply_patterns to
    # CHECK: transform.apply_patterns.canonicalization
    # CHECK: !transform.any_op


@run
def testApplyPatternsOpWithType(module: Module):
    """Build apply_patterns with canonicalization on a concretely typed handle."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("test.dummy"),
    )
    with InsertionPoint(sequence.body):
        with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns):
            transform.ApplyCanonicalizationPatternsOp()
        transform.YieldOp()
    # CHECK-LABEL: TEST: testApplyPatternsOpWithType
    # CHECK: apply_patterns to
    # CHECK: transform.apply_patterns.canonicalization
    # CHECK: !transform.op<"test.dummy">


@run
def testReplicateOp(module: Module):
    """Build replicate, using one PDL match as the count and another as payload."""
    with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
    with InsertionPoint(with_pdl.body):
        sequence = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate, [], with_pdl.bodyTarget
        )
        with InsertionPoint(sequence.body):
            m1 = transform_pdl.PDLMatchOp(
                transform.AnyOpType.get(), sequence.bodyTarget, "first"
            )
            m2 = transform_pdl.PDLMatchOp(
                transform.AnyOpType.get(), sequence.bodyTarget, "second"
            )
            transform.ReplicateOp(m1, [m2])
            transform.YieldOp()
    # CHECK-LABEL: TEST: testReplicateOp
    # CHECK: %[[FIRST:.+]] = pdl_match
    # CHECK: %[[SECOND:.+]] = pdl_match
    # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]