# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects.pdl import *


# Each test prints a "TEST: <name>" header followed by its module; every group
# of directives below starts with a label on that header so a test's patterns
# can only match that test's own output.


def constructAndPrintInModule(f):
    """Run ``f`` inside a fresh module and print the resulting IR.

    Used as a decorator: ``f`` is executed immediately, at definition time, so
    the tests run and print in source order. Before the module is printed, a
    ``TEST: <function name>`` header line is emitted for the label directives
    to anchor on.

    Args:
      f: zero-argument callable that builds operations at the current
        insertion point, which is the body of the newly created module.

    Returns:
      ``f`` itself, so the decorated name stays bound to the test function.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f()
        print(module)
    return f


# CHECK-LABEL: TEST: test_operations
# CHECK: module {
# CHECK: pdl.pattern @operations : benefit(1) {
# CHECK: %0 = attribute
# CHECK: %1 = type
# CHECK: %2 = operation {"attr" = %0} -> (%1 : !pdl.type)
# CHECK: %3 = result 0 of %2
# CHECK: %4 = operand
# CHECK: %5 = operation(%3, %4 : !pdl.value, !pdl.value)
# CHECK: rewrite %5 with "rewriter"
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_operations():
    pattern = PatternOp(1, "operations")
    with InsertionPoint(pattern.body):
        attr = AttributeOp()
        ty = TypeOp()
        op0 = OperationOp(attributes={"attr": attr}, types=[ty])
        op0_result = ResultOp(op0, 0)
        # Named `operand` rather than `input` to avoid shadowing the builtin.
        operand = OperandOp()
        root = OperationOp(args=[op0_result, operand])
        RewriteOp(root, "rewriter")


# CHECK-LABEL: TEST: test_rewrite_with_args
# CHECK: module {
# CHECK: pdl.pattern @rewrite_with_args : benefit(1) {
# CHECK: %0 = operand
# CHECK: %1 = operation(%0 : !pdl.value)
# CHECK: rewrite %1 with "rewriter"(%0 : !pdl.value)
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_with_args():
    pattern = PatternOp(1, "rewrite_with_args")
    with InsertionPoint(pattern.body):
        operand = OperandOp()
        root = OperationOp(args=[operand])
        RewriteOp(root, "rewriter", args=[operand])


# CHECK-LABEL: TEST: test_rewrite_multi_root_optimal
# CHECK: module {
# CHECK: pdl.pattern @rewrite_multi_root_optimal : benefit(1) {
# CHECK: %0 = operand
# CHECK: %1 = operand
# CHECK: %2 = type
# CHECK: %3 = operation(%0 : !pdl.value) -> (%2 : !pdl.type)
# CHECK: %4 = result 0 of %3
# CHECK: %5 = operation(%4 : !pdl.value)
# CHECK: %6 = operation(%1 : !pdl.value) -> (%2 : !pdl.type)
# CHECK: %7 = result 0 of %6
# CHECK: %8 = operation(%4, %7 : !pdl.value, !pdl.value)
# CHECK: rewrite with "rewriter"(%5, %8 : !pdl.operation, !pdl.operation)
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_optimal():
    pattern = PatternOp(1, "rewrite_multi_root_optimal")
    with InsertionPoint(pattern.body):
        input1 = OperandOp()
        input2 = OperandOp()
        ty = TypeOp()
        op1 = OperationOp(args=[input1], types=[ty])
        val1 = ResultOp(op1, 0)
        root1 = OperationOp(args=[val1])
        op2 = OperationOp(args=[input2], types=[ty])
        val2 = ResultOp(op2, 0)
        root2 = OperationOp(args=[val1, val2])
        # No explicit root: both roots are passed as rewriter arguments.
        RewriteOp(name="rewriter", args=[root1, root2])


# CHECK-LABEL: TEST: test_rewrite_multi_root_forced
# CHECK: module {
# CHECK: pdl.pattern @rewrite_multi_root_forced : benefit(1) {
# CHECK: %0 = operand
# CHECK: %1 = operand
# CHECK: %2 = type
# CHECK: %3 = operation(%0 : !pdl.value) -> (%2 : !pdl.type)
# CHECK: %4 = result 0 of %3
# CHECK: %5 = operation(%4 : !pdl.value)
# CHECK: %6 = operation(%1 : !pdl.value) -> (%2 : !pdl.type)
# CHECK: %7 = result 0 of %6
# CHECK: %8 = operation(%4, %7 : !pdl.value, !pdl.value)
# CHECK: rewrite %5 with "rewriter"(%8 : !pdl.operation)
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_forced():
    pattern = PatternOp(1, "rewrite_multi_root_forced")
    with InsertionPoint(pattern.body):
        input1 = OperandOp()
        input2 = OperandOp()
        ty = TypeOp()
        op1 = OperationOp(args=[input1], types=[ty])
        val1 = ResultOp(op1, 0)
        root1 = OperationOp(args=[val1])
        op2 = OperationOp(args=[input2], types=[ty])
        val2 = ResultOp(op2, 0)
        root2 = OperationOp(args=[val1, val2])
        # root1 is forced as the rewrite root; root2 becomes a rewriter arg.
        RewriteOp(root1, name="rewriter", args=[root2])


# CHECK-LABEL: TEST: test_rewrite_add_body
# CHECK: module {
# CHECK: pdl.pattern @rewrite_add_body : benefit(1) {
# CHECK: %0 = type : i32
# CHECK: %1 = type
# CHECK: %2 = operation -> (%0, %1 : !pdl.type, !pdl.type)
# CHECK: rewrite %2 {
# CHECK: %3 = type
# CHECK: %4 = operation "foo.op" -> (%0, %3 : !pdl.type, !pdl.type)
# CHECK: replace %2 with %4
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_add_body():
    pattern = PatternOp(1, "rewrite_add_body")
    with InsertionPoint(pattern.body):
        ty1 = TypeOp(IntegerType.get_signless(32))
        ty2 = TypeOp()
        root = OperationOp(types=[ty1, ty2])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            ty3 = TypeOp()
            newOp = OperationOp(name="foo.op", types=[ty1, ty3])
            ReplaceOp(root, with_op=newOp)


# CHECK-LABEL: TEST: test_rewrite_type
# CHECK: module {
# CHECK: pdl.pattern @rewrite_type : benefit(1) {
# CHECK: %0 = type : i32
# CHECK: %1 = type
# CHECK: %2 = operation -> (%0, %1 : !pdl.type, !pdl.type)
# CHECK: rewrite %2 {
# CHECK: %3 = operation "foo.op" -> (%0, %1 : !pdl.type, !pdl.type)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_type():
    pattern = PatternOp(1, "rewrite_type")
    with InsertionPoint(pattern.body):
        ty1 = TypeOp(IntegerType.get_signless(32))
        ty2 = TypeOp()
        root = OperationOp(types=[ty1, ty2])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            # Created for its side effect (insertion into the rewrite body).
            OperationOp(name="foo.op", types=[ty1, ty2])


# CHECK-LABEL: TEST: test_rewrite_types
# CHECK: module {
# CHECK: pdl.pattern @rewrite_types : benefit(1) {
# CHECK: %0 = types
# CHECK: %1 = operation -> (%0 : !pdl.range<type>)
# CHECK: rewrite %1 {
# CHECK: %2 = types : [i32, i64]
# CHECK: %3 = operation "foo.op" -> (%0, %2 : !pdl.range<type>, !pdl.range<type>)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_types():
    pattern = PatternOp(1, "rewrite_types")
    with InsertionPoint(pattern.body):
        types = TypesOp()
        root = OperationOp(types=[types])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            otherTypes = TypesOp(
                [IntegerType.get_signless(32), IntegerType.get_signless(64)]
            )
            OperationOp(name="foo.op", types=[types, otherTypes])


# CHECK-LABEL: TEST: test_rewrite_operands
# CHECK: module {
# CHECK: pdl.pattern @rewrite_operands : benefit(1) {
# CHECK: %0 = types
# CHECK: %1 = operands : %0
# CHECK: %2 = operation(%1 : !pdl.range<value>)
# CHECK: rewrite %2 {
# CHECK: %3 = operation "foo.op" -> (%0 : !pdl.range<type>)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_operands():
    pattern = PatternOp(1, "rewrite_operands")
    with InsertionPoint(pattern.body):
        types = TypesOp()
        operands = OperandsOp(types)
        root = OperationOp(args=[operands])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            OperationOp(name="foo.op", types=[types])


# CHECK-LABEL: TEST: test_native_rewrite
# CHECK: module {
# CHECK: pdl.pattern @native_rewrite : benefit(1) {
# CHECK: %0 = operation
# CHECK: rewrite %0 {
# CHECK: apply_native_rewrite "NativeRewrite"(%0 : !pdl.operation)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_native_rewrite():
    pattern = PatternOp(1, "native_rewrite")
    with InsertionPoint(pattern.body):
        root = OperationOp()
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            ApplyNativeRewriteOp([], "NativeRewrite", args=[root])


# CHECK-LABEL: TEST: test_attribute_with_value
# CHECK: module {
# CHECK: pdl.pattern @attribute_with_value : benefit(1) {
# CHECK: %0 = operation
# CHECK: rewrite %0 {
# CHECK: %1 = attribute = "value"
# CHECK: apply_native_rewrite "NativeRewrite"(%1 : !pdl.attribute)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_attribute_with_value():
    pattern = PatternOp(1, "attribute_with_value")
    with InsertionPoint(pattern.body):
        root = OperationOp()
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            attr = AttributeOp(value=Attribute.parse('"value"'))
            ApplyNativeRewriteOp([], "NativeRewrite", args=[attr])


# CHECK-LABEL: TEST: test_erase
# CHECK: module {
# CHECK: pdl.pattern @erase : benefit(1) {
# CHECK: %0 = operation
# CHECK: rewrite %0 {
# CHECK: erase %0
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_erase():
    pattern = PatternOp(1, "erase")
    with InsertionPoint(pattern.body):
        root = OperationOp()
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            EraseOp(root)


# CHECK-LABEL: TEST: test_operation_results
# CHECK: module {
# CHECK: pdl.pattern @operation_results : benefit(1) {
# CHECK: %0 = types
# CHECK: %1 = operation -> (%0 : !pdl.range<type>)
# CHECK: %2 = results of %1
# CHECK: %3 = operation(%2 : !pdl.range<value>)
# CHECK: rewrite %3 with "rewriter"
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_operation_results():
    # Result type of ResultsOp when it yields all results as a value range.
    valueRange = RangeType.get(ValueType.get())
    pattern = PatternOp(1, "operation_results")
    with InsertionPoint(pattern.body):
        types = TypesOp()
        inputOp = OperationOp(types=[types])
        results = ResultsOp(valueRange, inputOp)
        root = OperationOp(args=[results])
        RewriteOp(root, name="rewriter")


# CHECK-LABEL: TEST: test_apply_native_constraint
# CHECK: module {
# CHECK: pdl.pattern : benefit(1) {
# CHECK: %0 = type
# CHECK: apply_native_constraint "typeConstraint"(%0 : !pdl.type)
# CHECK: %1 = operation -> (%0 : !pdl.type)
# CHECK: rewrite %1 with "rewrite"
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_apply_native_constraint():
    # Unnamed pattern: the printed form has no "@name" symbol.
    pattern = PatternOp(1)
    with InsertionPoint(pattern.body):
        resultType = TypeOp()
        ApplyNativeConstraintOp([], "typeConstraint", args=[resultType])
        root = OperationOp(types=[resultType])
        RewriteOp(root, name="rewrite")