# RUN: %PYTHON %s | FileCheck %s

# Tests for the Python builders of the `transform.structured` ops. Each test
# builds ops into a fresh module, prints the module, and the printed IR is
# matched against the FileCheck directives written below the test body.

import functools
from typing import Callable

from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects import pdl
from mlir.dialects.transform import structured
from mlir.dialects.transform import pdl as transform_pdl
from mlir.dialects.transform.extras import constant_param


def run(f):
    """Execute test `f` inside a fresh context/module and print the result.

    The test name is printed first so the output can be split per test, then
    `f` is called with the insertion point set to the module body. The module
    is verified and printed afterwards. Returns `f` so this can be used as a
    decorator.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            print("\nTEST:", f.__name__)
            f()
        module.operation.verify()
        print(module)
    return f


def create_sequence(func: Callable) -> Callable:
    """Wrap `func` so that it runs inside the body of a transform sequence.

    The returned zero-argument callable creates a `transform.SequenceOp` with
    an `AnyOpType` block argument, calls `func` with that block argument while
    the insertion point is the sequence body, and terminates the body with a
    `transform.YieldOp`.
    """

    @functools.wraps(func)
    def decorated() -> None:
        sequence = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate,
            [],
            transform.AnyOpType.get(),
        )
        with InsertionPoint(sequence.body):
            func(sequence.bodyTarget)
            transform.YieldOp()

    return decorated


@run
@create_sequence
def testBufferizeToAllocationOpCompact(target):
    """Build bufferize_to_allocation from the target handle alone."""
    structured.BufferizeToAllocationOp(target)
    # CHECK-LABEL: TEST: testBufferizeToAllocationOpCompact
    # CHECK: transform.sequence
    # CHECK: transform.structured.bufferize_to_allocation


@run
@create_sequence
def testBufferizeToAllocationOpArgs(target):
    """Build bufferize_to_allocation with memory space, ops and flag set."""
    structured.BufferizeToAllocationOp(
        target,
        memory_space=3,
        memcpy_op="memref.copy",
        alloc_op="memref.alloca",
        bufferize_destination_only=True,
    )
    # CHECK-LABEL: TEST: testBufferizeToAllocationOpArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.bufferize_to_allocation
    # CHECK-SAME: alloc_op = "memref.alloca"
    # CHECK-SAME: bufferize_destination_only
    # CHECK-SAME: memcpy_op = "memref.copy"
    # CHECK-SAME: memory_space = 3


@run
@create_sequence
def testDecompose(target):
    """Build the decompose op on the target handle."""
    structured.DecomposeOp(target)
    # CHECK-LABEL: TEST: testDecompose
    # CHECK: transform.sequence
    # CHECK: transform.structured.decompose


@run
@create_sequence
def testFuseIntoContainingOpTypes(target):
    """Build fuse_into_containing_op with explicit result types."""
    fused = structured.MatchOp.match_op_names(target, ["test.dummy"])
    containing = structured.MatchOp.match_op_names(target, ["test.dummy"])
    structured.FuseIntoContainingOp(
        transform.OperationType.get("test.dummy"),
        transform.OperationType.get("test.dummy"),
        fused,
        containing,
    )
    # CHECK-LABEL: TEST: testFuseIntoContainingOpTypes
    # CHECK: = transform.structured.fuse_into_containing_op
    # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.op<"test.dummy">, !transform.op<"test.dummy">)


@run
@create_sequence
def testFuseIntoContainingOpCompact(target):
    """Build fuse_into_containing_op without explicit result types."""
    fused = structured.MatchOp.match_op_names(target, ["test.dummy"])
    containing = structured.MatchOp.match_op_names(target, ["test.dummy"])
    structured.FuseIntoContainingOp(fused, containing)
    # CHECK-LABEL: TEST: testFuseIntoContainingOpCompact
    # CHECK: = transform.structured.fuse_into_containing_op
    # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)


@run
@create_sequence
def testFuseOpCompact(target):
    """Build the fuse op from Python lists plus the apply_cleanup flag."""
    structured.FuseOp(
        target, tile_sizes=[4, 8], tile_interchange=[0, 1], apply_cleanup=True
    )
    # CHECK-LABEL: TEST: testFuseOpCompact
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
    # CHECK-SAME: interchange [0, 1] apply_cleanup = true
    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)


@run
@create_sequence
def testFuseOpNoArg(target):
    """Build the fuse op from the target handle alone."""
    structured.FuseOp(target)
    # CHECK-LABEL: TEST: testFuseOpNoArg
    # CHECK: transform.sequence
    # CHECK: %{{.+}} = transform.structured.fuse %{{.*}} :
    # CHECK-SAME: (!transform.any_op) -> !transform.any_op


@run
@create_sequence
def testFuseOpAttributes(target):
    """Build the fuse op from pre-built DenseI64ArrayAttr arguments."""
    attr = DenseI64ArrayAttr.get([4, 8])
    ichange = DenseI64ArrayAttr.get([0, 1])
    structured.FuseOp(target, tile_sizes=attr, tile_interchange=ichange)
    # CHECK-LABEL: TEST: testFuseOpAttributes
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
    # CHECK-SAME: interchange [0, 1]
    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)


@run
@create_sequence
def testGeneralize(target):
    """Build the generalize op on the target handle."""
    structured.GeneralizeOp(target)
    # CHECK-LABEL: TEST: testGeneralize
    # CHECK: transform.sequence
    # CHECK: transform.structured.generalize


@run
@create_sequence
def testInterchange(target):
    """Build the interchange op with a list-valued iterator_interchange."""
    structured.InterchangeOp(target, iterator_interchange=[1, 0])
    # CHECK-LABEL: TEST: testInterchange
    # CHECK: transform.sequence
    # CHECK: transform.structured.interchange
    # CHECK: iterator_interchange = [1, 0]


@run
@create_sequence
def testMapCopyToThreadsOpCompact(target):
    """Build gpu.map_copy_to_threads without explicit result types."""
    structured.MapCopyToThreadsOp(
        target, total_num_threads=32, desired_bit_alignment=128
    )
    # CHECK-LABEL: TEST: testMapCopyToThreadsOpCompact
    # CHECK: = transform.structured.gpu.map_copy_to_threads
    # CHECK-SAME: total_num_threads = 32
    # CHECK-SAME: desired_bit_alignment = 128
    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op)


@run
@create_sequence
def testMapCopyToThreadsOpTypes(target):
    """Build gpu.map_copy_to_threads with explicit result types."""
    structured.MapCopyToThreadsOp(
        transform.OperationType.get("test.opA"),
        transform.OperationType.get("test.opB"),
        target,
        total_num_threads=32,
        desired_bit_alignment=128,
    )
    # CHECK-LABEL: TEST: testMapCopyToThreadsOpTypes
    # CHECK: = transform.structured.gpu.map_copy_to_threads
    # CHECK-SAME: total_num_threads = 32
    # CHECK-SAME: desired_bit_alignment = 128
    # CHECK-SAME: (!transform.any_op) -> (!transform.op<"test.opA">, !transform.op<"test.opB">)


@run
@create_sequence
def testMatchOpNamesString(target):
    """Build a match op from a single op name given as a string."""
    structured.MatchOp.match_op_names(target, "test.dummy")
    # CHECK-LABEL: TEST: testMatchOpNamesString
    # CHECK: transform.structured.match ops
    # CHECK-SAME: ["test.dummy"]
    # CHECK-SAME: (!transform.any_op) -> !transform.any_op


@run
@create_sequence
def testMatchOpNamesList(target):
    """Build a match op from a list of op names."""
    structured.MatchOp.match_op_names(target, ["test.dummy"])
    # CHECK-LABEL: TEST: testMatchOpNamesList
    # CHECK: transform.structured.match ops
    # CHECK-SAME: ["test.dummy"]
    # CHECK-SAME: (!transform.any_op) -> !transform.any_op


@run
@create_sequence
def testVectorizeNoArgs(target):
    """Build the vectorize op from the target handle alone."""
    structured.VectorizeOp(target)
    # CHECK-LABEL: TEST: testVectorizeNoArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.vectorize
    # CHECK-NOT: vector_sizes


@run
@create_sequence
def testVectorizeStatic(target):
    """Build the vectorize op with integer vector sizes."""
    structured.VectorizeOp(target, [16, 4])
    # CHECK-LABEL: TEST: testVectorizeStatic
    # CHECK: transform.sequence
    # CHECK: transform.structured.vectorize
    # CHECK-SAME: vector_sizes [16, 4]


@run
@create_sequence
def testVectorizeArray(target):
    """Build the vectorize op with sizes given as a parsed array attribute."""
    sizes = Attribute.parse("[16, 4]")
    structured.VectorizeOp(target, sizes)
    # CHECK-LABEL: TEST: testVectorizeArray
    # CHECK: transform.sequence
    # CHECK: transform.structured.vectorize
    # CHECK-SAME: vector_sizes [16, 4]


@run
@create_sequence
def testVectorizeMixed(target):
    """Build the vectorize op mixing a handle and an attribute as sizes."""
    sz1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
    sz2 = Attribute.parse("4")
    structured.VectorizeOp(target, [sz1, sz2])
    # CHECK-LABEL: TEST: testVectorizeMixed
    # CHECK: transform.sequence
    # CHECK: %[[V0:.*]] = transform.structured.match
    # CHECK: transform.structured.vectorize
    # CHECK-SAME: vector_sizes [%[[V0]], 4]


@run
@create_sequence
def testVectorizeEmpty(target):
    """Build the vectorize op with an empty list of vector sizes."""
    structured.VectorizeOp(target, [])
    # CHECK-LABEL: TEST: testVectorizeEmpty
    # CHECK: transform.sequence
    # CHECK: transform.structured.vectorize
    # CHECK-NOT: vector_sizes


@run
@create_sequence
def testVectorizeScalable(target):
    """Build the vectorize op with sizes wrapped in single-element lists."""
    sz1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
    sz2 = Attribute.parse("4")
    structured.VectorizeOp(target, [16, [sz1], [sz2], [8]])
    # CHECK-LABEL: TEST: testVectorizeScalable
    # CHECK: transform.sequence
    # CHECK-DAG: %[[V0:.*]] = transform.structured.match
    # CHECK-DAG: transform.structured.vectorize
    # CHECK-SAME: vector_sizes [16, [%[[V0]]], [4], [8]]


@run
@create_sequence
def testVectorizeArgs(target):
    """Build the vectorize op with the vectorize_nd_extract flag set."""
    structured.VectorizeOp(target, [16, 4], vectorize_nd_extract=True)
    # CHECK-LABEL: TEST: testVectorizeArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.vectorize
    # CHECK-SAME: vectorize_nd_extract


@run
@create_sequence
def testMatchOpNamesTyped(target):
    """Build a match op with an explicit result type."""
    structured.MatchOp.match_op_names(
        transform.OperationType.get("test.dummy"),
        target,
        ["test.dummy"],
    )
    # CHECK-LABEL: TEST: testMatchOpNamesTyped
    # CHECK: transform.structured.match ops
    # CHECK-SAME: ["test.dummy"]
    # CHECK-SAME: (!transform.any_op) -> !transform.op<"test.dummy">


@run
@create_sequence
def testMultitileSizesCompact(target):
    """Build multitile_sizes without passing a divisor."""
    structured.MultiTileSizesOp(
        transform.AnyOpType.get(), target, dimension=1, target_size=42
    )
    # CHECK-LABEL: TEST: testMultitileSizesCompact
    # CHECK: transform.sequence
    # CHECK-NOT: divisor
    # CHECK: transform.structured.multitile_sizes
    # CHECK-NOT: divisor
    # CHECK-DAG: dimension = 1
    # CHECK-NOT: divisor
    # CHECK-DAG: target_size = 42
    # CHECK-NOT: divisor


@run
@create_sequence
def testMultitileSizesAllArgs(target):
    """Build multitile_sizes with dimension, target_size and divisor."""
    structured.MultiTileSizesOp(
        transform.AnyOpType.get(),
        target,
        dimension=1,
        target_size=42,
        divisor=2,
    )
    # CHECK-LABEL: TEST: testMultitileSizesAllArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.multitile_sizes
    # CHECK-DAG: dimension = 1
    # CHECK-DAG: divisor = 2
    # CHECK-DAG: target_size = 42


@run
@create_sequence
def testPadOpNoArgs(target):
    """Build the pad op from the target handle alone."""
    structured.PadOp(target)
    # CHECK-LABEL: TEST: testPadOpNoArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.pad
    # CHECK-NOT: copy_back_op
    # CHECK-NOT: nofold_flags
    # CHECK-NOT: pad_to_multiple_of
    # CHECK-NOT: padding_dimensions
    # CHECK-NOT: padding_values
    # CHECK-NOT: transpose_paddings


@run
@create_sequence
def testPadOpArgs(target):
    """Build the pad op mixing Python values and attributes as arguments."""
    structured.PadOp(
        target,
        pad_to_multiple_of=[128],
        padding_values=[FloatAttr.get_f32(42.0), StringAttr.get("0")],
        padding_dimensions=Attribute.parse("[1]"),
        nofold_flags=[0],
        transpose_paddings=[[1, Attribute.parse("0")], Attribute.parse("[0, 1]")],
        copy_back_op="linalg.copy",
    )
    # CHECK-LABEL: TEST: testPadOpArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.pad
    # CHECK-DAG: pad_to_multiple_of [128]
    # CHECK-DAG: copy_back_op = "linalg.copy"
    # CHECK-DAG: nofold_flags = [0]
    # CHECK-DAG: padding_dimensions = [1]
    # CHECK-DAG: padding_values = [4.200000e+01 : f32, "0"]
    # CHECK-DAG: transpose_paddings = {{\[}}[1, 0], [0, 1]]


@run
@create_sequence
def testPadOpArgsParam(target):
    """Build the pad op with a constant param among pad_to_multiple_of."""
    structured.PadOp(
        target,
        pad_to_multiple_of=[constant_param(128), Attribute.parse("2"), 10],
        padding_dimensions=Attribute.parse("[0, 1, 2]"),
    )
    # CHECK-LABEL: TEST: testPadOpArgsParam
    # CHECK: transform.sequence
    # CHECK-DAG: %[[P:.*]] = transform.param.constant 128
    # CHECK: transform.structured.pad
    # CHECK-DAG: pad_to_multiple_of [%[[P]], 2, 10]
    # CHECK-DAG: padding_dimensions = [0, 1, 2]


@run
@create_sequence
def testScalarize(target):
    """Build the scalarize op on the target handle."""
    structured.ScalarizeOp(target)
    # CHECK-LABEL: TEST: testScalarize
    # CHECK: transform.structured.scalarize


@run
@create_sequence
def testSplit(target):
    """Build split ops with an integer and with a handle as chunk_sizes."""
    handle = structured.SplitOp(target, dimension=1, chunk_sizes=42)
    split = transform.SplitHandleOp(
        [transform.AnyOpType.get(), transform.AnyOpType.get()], handle
    )
    structured.SplitOp(split.results[0], dimension=3, chunk_sizes=split.results[1])
    # CHECK-LABEL: TEST: testSplit
    # CHECK: %[[G:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
    # CHECK: %[[F:.+]]:2 = split_handle %[[G]]
    # CHECK: transform.structured.split %[[F]]#0 after %[[F]]#1 {dimension = 3


@run
@create_sequence
def testTileCompact(target):
    """Build tile_using_for from Python lists of sizes and interchange."""
    structured.TileUsingForOp(target, sizes=[4, 8], interchange=[0, 1])
    # CHECK-LABEL: TEST: testTileCompact
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile_using_for %{{.*}}[4, 8]
    # CHECK: interchange = [0, 1]


@run
@create_sequence
def testTileAttributes(target):
    """Build tile_using_for from pre-built DenseI64ArrayAttr arguments."""
    attr = DenseI64ArrayAttr.get([4, 8])
    ichange = DenseI64ArrayAttr.get([0, 1])
    structured.TileUsingForOp(target, sizes=attr, interchange=ichange)
    # CHECK-LABEL: TEST: testTileAttributes
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile_using_for %{{.*}}[4, 8]
    # CHECK: interchange = [0, 1]


@run
@create_sequence
def testTileZero(target):
    """Build tile_using_for with zero entries among the tile sizes."""
    structured.TileUsingForOp(target, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3])
    # CHECK-LABEL: TEST: testTileZero
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile_using_for %{{.*}}[4, 0, 2, 0]
    # CHECK: interchange = [0, 1, 2, 3]


@run
def testTileDynamic():
    """Build tile_using_for with PDL match results used as tile sizes.

    This test creates its own sequence nested in a with_pdl_patterns op rather
    than using `create_sequence`.
    """
    with_pdl = transform_pdl.WithPDLPatternsOp(pdl.OperationType.get())
    with InsertionPoint(with_pdl.body):
        sequence = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate, [], with_pdl.bodyTarget
        )
        with InsertionPoint(sequence.body):
            m1 = transform_pdl.PDLMatchOp(
                pdl.OperationType.get(), sequence.bodyTarget, "first"
            )
            m2 = transform_pdl.PDLMatchOp(
                pdl.OperationType.get(), sequence.bodyTarget, "second"
            )
            structured.TileUsingForOp(sequence.bodyTarget, sizes=[m1, 3, m2, 0])
            transform.YieldOp()
    # CHECK-LABEL: TEST: testTileDynamic
    # CHECK: %[[FIRST:.+]] = pdl_match
    # CHECK: %[[SECOND:.+]] = pdl_match
    # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile_using_for %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]


@run
@create_sequence
def testTileExplicitLoopTypeSingle(target):
    """Build tile_using_for with one loop type given for all loops."""
    structured.TileUsingForOp(
        transform.OperationType.get("scf.for"), target, sizes=[2, 3, 4]
    )
    # CHECK-LABEL: TEST: testTileExplicitLoopTypeSingle
    # CHECK: = transform.structured.tile_using_for %{{.*}} : (!{{.*}}) ->
    # CHECK-COUNT-3: !transform.op<"scf.for">


@run
@create_sequence
def testTileExplicitLoopTypeAll(target):
    """Build tile_using_for with a distinct loop type per loop."""
    types = [
        transform.OperationType.get(x)
        for x in ["scf.for", "scf.parallel", "scf.forall"]
    ]
    structured.TileUsingForOp(types, target, sizes=[2, 3, 4])
    # CHECK-LABEL: TEST: testTileExplicitLoopTypeAll
    # CHECK: = transform.structured.tile
    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">,
    # CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall">


@run
@create_sequence
def testTileScalable(target):
    """Build tile_using_for with a size wrapped in a single-element list."""
    structured.TileUsingForOp(
        target,
        sizes=[4, [2]],
    )
    # CHECK-LABEL: TEST: testTileScalable
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile_using_for %{{.*}}[4, [2]]


@run
@create_sequence
def testTileToForallCompact(target):
    """Build tile_using_forall on a handle cast to linalg.matmul."""
    matmul = transform.CastOp(transform.OperationType.get("linalg.matmul"), target)
    structured.TileUsingForallOp(matmul, num_threads=[2, 3, 4])
    # CHECK-LABEL: TEST: testTileToForallCompact
    # CHECK: = transform.structured.tile_using_forall
    # CHECK-SAME: num_threads [2, 3, 4]
    # CHECK-SAME: (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op)


@run
@create_sequence
def testTileToForallLoopsAndTileOpTypes(target):
    """Build tile_using_forall with explicit loop and tiled-op types."""
    structured.TileUsingForallOp(
        transform.OperationType.get("scf.forall"),  # loops_type
        transform.OperationType.get("linalg.matmul"),  # tiled_op_type
        target,
        num_threads=[2, 3, 4],
    )
    # CHECK-LABEL: TEST: testTileToForallLoopsAndTileOpTypes
    # CHECK: = transform.structured.tile_using_forall
    # CHECK-SAME: num_threads [2, 3, 4]
    # CHECK-SAME: (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.matmul">)


@run
@create_sequence
def testTileToForallTileSizes(target):
    """Build tile_using_forall with tile_sizes instead of num_threads."""
    structured.TileUsingForallOp(target, tile_sizes=[2, 3, 4])
    # CHECK-LABEL: TEST: testTileToForallTileSizes
    # CHECK: = transform.structured.tile_using_forall
    # CHECK-SAME: tile_sizes [2, 3, 4]


@run
@create_sequence
def testTileToForallMixedDynamic(target):
    """Build tile_using_forall mixing a handle and integers as num_threads."""
    n = structured.MatchOp.match_op_names(target, ["test.dummy"])
    structured.TileUsingForallOp(target, num_threads=[n, 3, 4])
    # CHECK-LABEL: TEST: testTileToForallMixedDynamic
    # CHECK: = transform.structured.tile_using_forall
    # CHECK-SAME: num_threads [%{{.*}}, 3, 4] : (!transform.any_op, !transform.any_op)


@run
@create_sequence
def testTileToForallPackedDynamic(target):
    """Build tile_using_forall with a single handle passed as num_threads."""
    n = structured.MatchOp.match_op_names(target, ["test.dummy"])
    structured.TileUsingForallOp(target, num_threads=n)
    # CHECK-LABEL: TEST: testTileToForallPackedDynamic
    # CHECK: = transform.structured.tile_using_forall
    # CHECK-SAME: num_threads *(%0) : (!transform.any_op, !transform.any_op)


@run
@create_sequence
def testTileToForallMapping(target):
    """Build tile_using_forall with a parsed GPU thread mapping attribute."""
    mapping = Attribute.parse("[ #gpu.thread<y>, #gpu.thread<x> ]")
    structured.TileUsingForallOp(target, num_threads=[2, 3], mapping=mapping)
    # CHECK-LABEL: TEST: testTileToForallMapping
    # CHECK: = transform.structured.tile_using_forall
    # CHECK-SAME: mapping = [#gpu.thread<y>, #gpu.thread<x>]


@run
@create_sequence
def testVectorizeChildrenAndApplyPatternsAllAttrs(target):
    """Build vectorize_children_and_apply_patterns with all flags True."""
    structured.VectorizeChildrenAndApplyPatternsOp(
        target,
        disable_multi_reduction_to_contract_patterns=True,
        disable_transfer_permutation_map_lowering_patterns=True,
        vectorize_nd_extract=True,
        vectorize_padding=True,
    )
    # CHECK-LABEL: TEST: testVectorizeChildrenAndApplyPatternsAllAttrs
    # CHECK: transform.sequence
    # CHECK: = transform.structured.vectorize
    # CHECK-SAME: disable_multi_reduction_to_contract_patterns
    # CHECK-SAME: disable_transfer_permutation_map_lowering_patterns
    # CHECK-SAME: vectorize_nd_extract
    # CHECK-SAME: vectorize_padding


@run
@create_sequence
def testVectorizeChildrenAndApplyPatternsNoAttrs(target):
    """Build vectorize_children_and_apply_patterns with all flags False."""
    structured.VectorizeChildrenAndApplyPatternsOp(
        target,
        disable_multi_reduction_to_contract_patterns=False,
        disable_transfer_permutation_map_lowering_patterns=False,
        vectorize_nd_extract=False,
        vectorize_padding=False,
    )
    # CHECK-LABEL: TEST: testVectorizeChildrenAndApplyPatternsNoAttrs
    # CHECK: transform.sequence
    # CHECK: = transform.structured.vectorize
    # CHECK-NOT: disable_multi_reduction_to_contract_patterns
    # CHECK-NOT: disable_transfer_permutation_map_lowering_patterns
    # CHECK-NOT: vectorize_nd_extract
    # CHECK-NOT: vectorize_padding


@run
@create_sequence
def testMatchInterfaceEnum(target):
    """Build a match op through the base class with an enum interface."""
    names = ArrayAttr.get([StringAttr.get("test.dummy")])
    result_type = transform.AnyOpType.get()
    # The op is inserted on construction; no handle to it is needed here.
    structured.MatchOp.__base__(
        result_type,
        target,
        ops=names,
        interface=structured.MatchInterfaceEnum.LinalgOp,
    )
    # CHECK-LABEL: TEST: testMatchInterfaceEnum
    # CHECK: transform.sequence
    # CHECK: = transform.structured.match
    # CHECK: interface{LinalgOp}


@run
@create_sequence
def testMatchInterfaceEnumReplaceAttributeBuilder(target):
    """Replace the MatchInterfaceEnum attribute builder and use it by name."""
    # Integer values the replacement builder assigns to each accepted name.
    interface_values = {"LinalgOp": 0, "TilingInterface": 1}

    @register_attribute_builder("MatchInterfaceEnum", replace=True)
    def match_interface_enum(x, context):
        # Fail with a clear message instead of hitting an unbound local when
        # an unrecognised name is passed.
        try:
            value = interface_values[x]
        except KeyError:
            raise ValueError(
                f"unsupported MatchInterfaceEnum value: {x!r}"
            ) from None
        return IntegerAttr.get(IntegerType.get_signless(32, context=context), value)

    names = ArrayAttr.get([StringAttr.get("test.dummy")])
    result_type = transform.AnyOpType.get()
    # The op is inserted on construction; no handle to it is needed here.
    structured.MatchOp.__base__(
        result_type,
        target,
        ops=names,
        interface="TilingInterface",
    )
    # CHECK-LABEL: TEST: testMatchInterfaceEnumReplaceAttributeBuilder
    # CHECK: transform.sequence
    # CHECK: = transform.structured.match
    # CHECK: interface{TilingInterface}