xref: /llvm-project/mlir/test/python/dialects/transform_structured_ext.py (revision 579ced4f8266b273d15b2801067a828151a222ef)
1# RUN: %PYTHON %s | FileCheck %s
2
3import functools
4from typing import Callable
5
6from mlir.ir import *
7from mlir.dialects import transform
8from mlir.dialects import pdl
9from mlir.dialects.transform import structured
10from mlir.dialects.transform import pdl as transform_pdl
11from mlir.dialects.transform.extras import constant_param
12
13
def run(f):
    """Build, verify and print one test module; usable as a decorator.

    Opens a fresh ``Context`` with ``Location.unknown()`` as the default
    location, creates an empty module and calls ``f`` with the insertion
    point at the module body. The ``TEST: <name>`` banner is printed before
    ``f`` runs so that the FileCheck label directives can anchor on it.
    ``f`` itself is returned unchanged.
    """
    with Context(), Location.unknown():
        test_module = Module.create()
        body_ip = InsertionPoint(test_module.body)
        with body_ip:
            print("\nTEST:", f.__name__)
            f()
        # NOTE(review): the return value of verify() is not inspected; this
        # presumably relies on it raising on failure -- confirm for the
        # bindings version in use.
        test_module.operation.verify()
        print(test_module)
    return f
23
24
def create_sequence(func: Callable) -> Callable:
    """Wrap ``func`` so that it runs inside a freshly built ``SequenceOp``.

    The returned zero-argument wrapper creates a ``transform.SequenceOp`` in
    ``Propagate`` failure mode with an empty list as its second argument and
    ``transform.AnyOpType`` as its third. It then calls ``func`` with the
    sequence's ``bodyTarget`` while inserting into the sequence body, and
    closes the body with ``transform.YieldOp``. ``functools.wraps`` keeps
    ``func``'s ``__name__`` so that ``run`` prints the real test name.
    """

    def decorated() -> None:
        seq = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate,
            [],
            transform.AnyOpType.get(),
        )
        with InsertionPoint(seq.body):
            func(seq.bodyTarget)
            transform.YieldOp()

    return functools.wraps(func)(decorated)
38
39
@run
@create_sequence
def testBufferizeToAllocationOpCompact(target):
    """Build bufferize_to_allocation on the sequence root with no options."""
    structured.BufferizeToAllocationOp(target)
    # CHECK-LABEL: TEST: testBufferizeToAllocationOpCompact
    # CHECK: transform.sequence
    # CHECK: transform.structured.bufferize_to_allocation
47
48
@run
@create_sequence
def testBufferizeToAllocationOpArgs(target):
    """Build bufferize_to_allocation with every optional keyword set."""
    # The printed attribute dictionary is in alphabetical order, which is why
    # the expected attributes below do not follow this keyword order.
    options = dict(
        memory_space=3,
        memcpy_op="memref.copy",
        alloc_op="memref.alloca",
        bufferize_destination_only=True,
    )
    structured.BufferizeToAllocationOp(target, **options)
    # CHECK-LABEL: TEST: testBufferizeToAllocationOpArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.bufferize_to_allocation
    # CHECK-SAME: alloc_op = "memref.alloca"
    # CHECK-SAME: bufferize_destination_only
    # CHECK-SAME: memcpy_op = "memref.copy"
    # CHECK-SAME: memory_space = 3
66
67
@run
@create_sequence
def testDecompose(target):
    """Build structured.decompose on the sequence root."""
    structured.DecomposeOp(target)
    # CHECK-LABEL: TEST: testDecompose
    # CHECK: transform.sequence
    # CHECK: transform.structured.decompose
75
76
@run
@create_sequence
def testFuseIntoContainingOpTypes(target):
    """Build fuse_into_containing_op with both result types given explicitly."""
    producer = structured.MatchOp.match_op_names(target, ["test.dummy"])
    container = structured.MatchOp.match_op_names(target, ["test.dummy"])
    dummy_type = transform.OperationType.get("test.dummy")
    structured.FuseIntoContainingOp(dummy_type, dummy_type, producer, container)
    # CHECK-LABEL: TEST: testFuseIntoContainingOpTypes
    # CHECK: = transform.structured.fuse_into_containing_op
    # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.op<"test.dummy">, !transform.op<"test.dummy">)
91
92
@run
@create_sequence
def testFuseIntoContainingOpCompact(target):
    """Build fuse_into_containing_op with inferred (any_op) result types."""
    producer = structured.MatchOp.match_op_names(target, ["test.dummy"])
    container = structured.MatchOp.match_op_names(target, ["test.dummy"])
    structured.FuseIntoContainingOp(producer, container)
    # CHECK-LABEL: TEST: testFuseIntoContainingOpCompact
    # CHECK: = transform.structured.fuse_into_containing_op
    # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
102
103
@run
@create_sequence
def testFuseOpCompact(target):
    """Build structured.fuse from plain Python lists plus apply_cleanup."""
    sizes = [4, 8]
    permutation = [0, 1]
    structured.FuseOp(
        target,
        tile_sizes=sizes,
        tile_interchange=permutation,
        apply_cleanup=True,
    )
    # CHECK-LABEL: TEST: testFuseOpCompact
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
    # CHECK-SAME: interchange [0, 1] apply_cleanup = true
    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
115
116
@run
@create_sequence
def testFuseOpNoArg(target):
    """Build structured.fuse with no tile sizes; a single result is expected."""
    structured.FuseOp(target)
    # CHECK-LABEL: TEST: testFuseOpNoArg
    # CHECK: transform.sequence
    # CHECK: %{{.+}} = transform.structured.fuse %{{.*}} :
    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
125
126
@run
@create_sequence
def testFuseOpAttributes(target):
    """Build structured.fuse from prebuilt DenseI64ArrayAttr values."""
    sizes_attr = DenseI64ArrayAttr.get([4, 8])
    interchange_attr = DenseI64ArrayAttr.get([0, 1])
    structured.FuseOp(
        target, tile_sizes=sizes_attr, tile_interchange=interchange_attr
    )
    # CHECK-LABEL: TEST: testFuseOpAttributes
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
    # CHECK-SAME: interchange [0, 1]
    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
138
139
@run
@create_sequence
def testGeneralize(target):
    """Build structured.generalize on the sequence root."""
    structured.GeneralizeOp(target)
    # CHECK-LABEL: TEST: testGeneralize
    # CHECK: transform.sequence
    # CHECK: transform.structured.generalize
147
148
@run
@create_sequence
def testInterchange(target):
    """Build structured.interchange with a two-dimensional permutation."""
    permutation = [1, 0]
    structured.InterchangeOp(target, iterator_interchange=permutation)
    # CHECK-LABEL: TEST: testInterchange
    # CHECK: transform.sequence
    # CHECK: transform.structured.interchange
    # CHECK: iterator_interchange = [1, 0]
157
158
@run
@create_sequence
def testMapCopyToThreadsOpCompact(target):
    """Build gpu.map_copy_to_threads with inferred (any_op) result types."""
    thread_config = dict(total_num_threads=32, desired_bit_alignment=128)
    structured.MapCopyToThreadsOp(target, **thread_config)
    # CHECK-LABEL: TEST: testMapCopyToThreadsOpCompact
    # CHECK: = transform.structured.gpu.map_copy_to_threads
    # CHECK-SAME: total_num_threads = 32
    # CHECK-SAME: desired_bit_alignment = 128
    # CHECK-SAME:  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
170
171
@run
@create_sequence
def testMapCopyToThreadsOpTypes(target):
    """Build gpu.map_copy_to_threads with both result types given explicitly."""
    first_result_type = transform.OperationType.get("test.opA")
    second_result_type = transform.OperationType.get("test.opB")
    structured.MapCopyToThreadsOp(
        first_result_type,
        second_result_type,
        target,
        total_num_threads=32,
        desired_bit_alignment=128,
    )
    # CHECK-LABEL: TEST: testMapCopyToThreadsOpTypes
    # CHECK: = transform.structured.gpu.map_copy_to_threads
    # CHECK-SAME: total_num_threads = 32
    # CHECK-SAME: desired_bit_alignment = 128
    # CHECK-SAME:  (!transform.any_op) -> (!transform.op<"test.opA">, !transform.op<"test.opB">)
187
188
@run
@create_sequence
def testMatchOpNamesString(target):
    """Build structured.match from a single op name passed as a bare string."""
    op_name = "test.dummy"
    structured.MatchOp.match_op_names(target, op_name)
    # CHECK-LABEL: TEST: testMatchOpNamesString
    # CHECK: transform.structured.match ops
    # CHECK-SAME: ["test.dummy"]
    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
197
198
@run
@create_sequence
def testMatchOpNamesList(target):
    """Build structured.match from a list of op names."""
    op_names = ["test.dummy"]
    structured.MatchOp.match_op_names(target, op_names)
    # CHECK-LABEL: TEST: testMatchOpNamesList
    # CHECK: transform.structured.match ops
    # CHECK-SAME: ["test.dummy"]
    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
207
208
@run
@create_sequence
def testVectorizeNoArgs(target):
    """Build structured.vectorize without sizes; no vector_sizes may print."""
    structured.VectorizeOp(target)
    # CHECK-LABEL: TEST: testVectorizeNoArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.vectorize
    # CHECK-NOT:     vector_sizes
217
218
@run
@create_sequence
def testVectorizeStatic(target):
    """Build structured.vectorize with static integer vector sizes."""
    static_sizes = [16, 4]
    structured.VectorizeOp(target, static_sizes)
    # CHECK-LABEL: TEST: testVectorizeStatic
    # CHECK: transform.sequence
    # CHECK: transform.structured.vectorize
    # CHECK-SAME:     vector_sizes [16, 4]
227
228
@run
@create_sequence
def testVectorizeArray(target):
    """Build structured.vectorize from a parsed array attribute of sizes."""
    sizes_attr = Attribute.parse("[16, 4]")
    structured.VectorizeOp(target, sizes_attr)
    # CHECK-LABEL: TEST: testVectorizeArray
    # CHECK: transform.sequence
    # CHECK: transform.structured.vectorize
    # CHECK-SAME:     vector_sizes [16, 4]
238
239
@run
@create_sequence
def testVectorizeMixed(target):
    """Build structured.vectorize mixing a handle-valued size and an attribute."""
    # The match op must exist before the vectorize op that consumes it.
    dynamic_size = structured.MatchOp.match_op_names(target, ["arith.constant"])
    static_size = Attribute.parse("4")
    structured.VectorizeOp(target, [dynamic_size, static_size])
    # CHECK-LABEL: TEST: testVectorizeMixed
    # CHECK: transform.sequence
    # CHECK: %[[V0:.*]] = transform.structured.match
    # CHECK: transform.structured.vectorize
    # CHECK-SAME:     vector_sizes [%[[V0]], 4]
251
252
@run
@create_sequence
def testVectorizeEmpty(target):
    """An explicitly empty size list must print like the no-argument form."""
    empty_sizes = []
    structured.VectorizeOp(target, empty_sizes)
    # CHECK-LABEL: TEST: testVectorizeEmpty
    # CHECK: transform.sequence
    # CHECK: transform.structured.vectorize
    # CHECK-NOT:     vector_sizes
261
262
@run
@create_sequence
def testVectorizeScalable(target):
    """Build structured.vectorize with scalable sizes (one-element sublists)."""
    handle_size = structured.MatchOp.match_op_names(target, ["arith.constant"])
    attr_size = Attribute.parse("4")
    # A size wrapped in its own list is printed in brackets, i.e. scalable.
    mixed_sizes = [16, [handle_size], [attr_size], [8]]
    structured.VectorizeOp(target, mixed_sizes)
    # CHECK-LABEL: TEST: testVectorizeScalable
    # CHECK: transform.sequence
    # CHECK-DAG: %[[V0:.*]] = transform.structured.match
    # CHECK-DAG: transform.structured.vectorize
    # CHECK-SAME:     vector_sizes [16, [%[[V0]]], [4], [8]]
274
275
@run
@create_sequence
def testVectorizeArgs(target):
    """Build structured.vectorize with the vectorize_nd_extract unit flag."""
    static_sizes = [16, 4]
    structured.VectorizeOp(target, static_sizes, vectorize_nd_extract=True)
    # CHECK-LABEL: TEST: testVectorizeArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.vectorize
    # CHECK-SAME: vectorize_nd_extract
284
285
@run
@create_sequence
def testMatchOpNamesTyped(target):
    """Build structured.match with an explicit result type as first argument."""
    dummy_type = transform.OperationType.get("test.dummy")
    structured.MatchOp.match_op_names(dummy_type, target, ["test.dummy"])
    # CHECK-LABEL: TEST: testMatchOpNamesTyped
    # CHECK: transform.structured.match ops
    # CHECK-SAME: ["test.dummy"]
    # CHECK-SAME: (!transform.any_op) -> !transform.op<"test.dummy">
298
299
@run
@create_sequence
def testMultitileSizesCompact(target):
    """Build multitile_sizes without a divisor; no divisor may be printed."""
    structured.MultiTileSizesOp(
        transform.AnyOpType.get(), target, dimension=1, target_size=42
    )
    # The label uses the full test name, like every other test in this file;
    # a bare "testMultitileSizes" prefix also matches the AllArgs banner.
    # CHECK-LABEL: TEST: testMultitileSizesCompact
    # CHECK: transform.sequence
    # CHECK-NOT: divisor
    # CHECK: transform.structured.multitile_sizes
    # CHECK-NOT: divisor
    # CHECK-DAG: dimension = 1
    # CHECK-NOT: divisor
    # CHECK-DAG: target_size = 42
    # CHECK-NOT: divisor
315
316
@run
@create_sequence
def testMultitileSizesAllArgs(target):
    """Build multitile_sizes with dimension, target_size and divisor set."""
    structured.MultiTileSizesOp(
        transform.AnyOpType.get(),
        target,
        dimension=1,
        target_size=42,
        divisor=2,
    )
    # The label uses the full test name so this section cannot be anchored
    # on the banner of another testMultitileSizes* test.
    # CHECK-LABEL: TEST: testMultitileSizesAllArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.multitile_sizes
    # CHECK-DAG: dimension = 1
    # CHECK-DAG: divisor = 2
    # CHECK-DAG: target_size = 42
333
334
@run
@create_sequence
def testPadOpNoArgs(target):
    """Build structured.pad with no options; none of them may be printed."""
    structured.PadOp(target)
    # CHECK-LABEL: TEST: testPadOpNoArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.pad
    # CHECK-NOT: copy_back_op
    # CHECK-NOT: nofold_flags
    # CHECK-NOT: pad_to_multiple_of
    # CHECK-NOT: padding_dimensions
    # CHECK-NOT: padding_values
    # CHECK-NOT: transpose_paddings
348
349
@run
@create_sequence
def testPadOpArgs(target):
    """Build structured.pad with every option, mixing Python values and attrs."""
    multiples = [128]
    values = [FloatAttr.get_f32(42.0), StringAttr.get("0")]
    dims = Attribute.parse("[1]")
    nofold = [0]
    # Each inner permutation may be a Python list (itself mixing ints and
    # attributes) or a parsed array attribute.
    transposes = [[1, Attribute.parse("0")], Attribute.parse("[0, 1]")]
    structured.PadOp(
        target,
        pad_to_multiple_of=multiples,
        padding_values=values,
        padding_dimensions=dims,
        nofold_flags=nofold,
        transpose_paddings=transposes,
        copy_back_op="linalg.copy",
    )
    # CHECK-LABEL: TEST: testPadOpArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.pad
    # CHECK-DAG: pad_to_multiple_of [128]
    # CHECK-DAG: copy_back_op = "linalg.copy"
    # CHECK-DAG: nofold_flags = [0]
    # CHECK-DAG: padding_dimensions = [1]
    # CHECK-DAG: padding_values = [4.200000e+01 : f32, "0"]
    # CHECK-DAG: transpose_paddings = {{\[}}[1, 0], [0, 1]]
371
372
@run
@create_sequence
def testPadOpArgsParam(target):
    """Build structured.pad whose multiples mix a param handle and constants."""
    # constant_param creates its op here, ahead of the pad op that uses it.
    param_multiple = constant_param(128)
    attr_multiple = Attribute.parse("2")
    structured.PadOp(
        target,
        pad_to_multiple_of=[param_multiple, attr_multiple, 10],
        padding_dimensions=Attribute.parse("[0, 1, 2]"),
    )
    # CHECK-LABEL: TEST: testPadOpArgsParam
    # CHECK: transform.sequence
    # CHECK-DAG: %[[P:.*]] = transform.param.constant 128
    # CHECK: transform.structured.pad
    # CHECK-DAG: pad_to_multiple_of [%[[P]], 2, 10]
    # CHECK-DAG: padding_dimensions = [0, 1, 2]
387
388
@run
@create_sequence
def testScalarize(target):
    """Build structured.scalarize on the sequence root."""
    structured.ScalarizeOp(target)
    # CHECK-LABEL: TEST: testScalarize
    # CHECK: transform.structured.scalarize
395
396
@run
@create_sequence
def testSplit(target):
    """Split statically, split the handle, then split again by a handle size."""
    first_split = structured.SplitOp(target, dimension=1, chunk_sizes=42)
    any_op = transform.AnyOpType.get()
    halves = transform.SplitHandleOp([any_op, any_op], first_split)
    # The second half of the handle split serves as the dynamic chunk size.
    structured.SplitOp(
        halves.results[0], dimension=3, chunk_sizes=halves.results[1]
    )
    # CHECK-LABEL: TEST: testSplit
    # CHECK: %[[G:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
    # CHECK: %[[F:.+]]:2 = split_handle %[[G]]
    # CHECK: transform.structured.split %[[F]]#0 after %[[F]]#1 {dimension = 3
409
410
@run
@create_sequence
def testTileCompact(target):
    """Build tile_using_for from plain Python lists of sizes and interchange."""
    tile_sizes = [4, 8]
    loop_order = [0, 1]
    structured.TileUsingForOp(target, sizes=tile_sizes, interchange=loop_order)
    # CHECK-LABEL: TEST: testTileCompact
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile_using_for %{{.*}}[4, 8]
    # CHECK: interchange = [0, 1]
419
420
@run
@create_sequence
def testTileAttributes(target):
    """Build tile_using_for from prebuilt DenseI64ArrayAttr values."""
    sizes_attr = DenseI64ArrayAttr.get([4, 8])
    order_attr = DenseI64ArrayAttr.get([0, 1])
    structured.TileUsingForOp(target, sizes=sizes_attr, interchange=order_attr)
    # CHECK-LABEL: TEST: testTileAttributes
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile_using_for %{{.*}}[4, 8]
    # CHECK: interchange = [0, 1]
431
432
@run
@create_sequence
def testTileZero(target):
    """Zero tile sizes are kept in the printed sizes but produce no loop."""
    tile_sizes = [4, 0, 2, 0]
    loop_order = [0, 1, 2, 3]
    structured.TileUsingForOp(target, sizes=tile_sizes, interchange=loop_order)
    # CHECK-LABEL: TEST: testTileZero
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile_using_for %{{.*}}[4, 0, 2, 0]
    # CHECK: interchange = [0, 1, 2, 3]
441
442
@run
def testTileDynamic():
    """Tile with sizes mixing PDLMatchOp handles and static integers.

    Unlike most tests here, this one builds its own ``WithPDLPatternsOp`` +
    ``SequenceOp`` nesting instead of using ``create_sequence``.
    """
    pdl_op_type = pdl.OperationType.get()
    with_pdl = transform_pdl.WithPDLPatternsOp(pdl_op_type)
    with InsertionPoint(with_pdl.body):
        seq = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate, [], with_pdl.bodyTarget
        )
        with InsertionPoint(seq.body):
            root = seq.bodyTarget
            first = transform_pdl.PDLMatchOp(pdl_op_type, root, "first")
            second = transform_pdl.PDLMatchOp(pdl_op_type, root, "second")
            structured.TileUsingForOp(root, sizes=[first, 3, second, 0])
            transform.YieldOp()
    # CHECK-LABEL: TEST: testTileDynamic
    # CHECK: %[[FIRST:.+]] = pdl_match
    # CHECK: %[[SECOND:.+]] = pdl_match
    # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile_using_for %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]
463
464
@run
@create_sequence
def testTileExplicitLoopTypeSingle(target):
    """A single loop type is applied to every generated loop."""
    loop_type = transform.OperationType.get("scf.for")
    structured.TileUsingForOp(loop_type, target, sizes=[2, 3, 4])
    # CHECK-LABEL: TEST: testTileExplicitLoopTypeSingle
    # CHECK: = transform.structured.tile_using_for %{{.*}} : (!{{.*}}) ->
    # CHECK-COUNT-3: !transform.op<"scf.for">
474
475
@run
@create_sequence
def testTileExplicitLoopTypeAll(target):
    """A list of loop types gives each generated loop its own handle type."""
    types = [
        transform.OperationType.get(x)
        for x in ["scf.for", "scf.parallel", "scf.forall"]
    ]
    structured.TileUsingForOp(types, target, sizes=[2, 3, 4])
    # Match the full op mnemonic (as testTileExplicitLoopTypeSingle does) and
    # close the result list so that no extra result can slip through.
    # CHECK-LABEL: TEST: testTileExplicitLoopTypeAll
    # CHECK: = transform.structured.tile_using_for
    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">,
    # CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall">)
488
489
@run
@create_sequence
def testTileScalable(target):
    """A size wrapped in a one-element list is printed as scalable ([2])."""
    mixed_sizes = [4, [2]]
    structured.TileUsingForOp(target, sizes=mixed_sizes)
    # CHECK-LABEL: TEST: testTileScalable
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile_using_for %{{.*}}[4, [2]]
500
501
@run
@create_sequence
def testTileToForallCompact(target):
    """Tile a cast, typed handle with tile_using_forall and static thread counts."""
    matmul_type = transform.OperationType.get("linalg.matmul")
    matmul = transform.CastOp(matmul_type, target)
    structured.TileUsingForallOp(matmul, num_threads=[2, 3, 4])
    # CHECK-LABEL: TEST: testTileToForallCompact
    # CHECK: = transform.structured.tile_using_forall
    # CHECK-SAME: num_threads [2, 3, 4]
    # CHECK-SAME: (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op)
511
512
@run
@create_sequence
def testTileToForallLoopsAndTileOpTypes(target):
    """Give tile_using_forall explicit loop and tiled-op result types."""
    loops_type = transform.OperationType.get("scf.forall")
    tiled_op_type = transform.OperationType.get("linalg.matmul")
    structured.TileUsingForallOp(
        loops_type,
        tiled_op_type,
        target,
        num_threads=[2, 3, 4],
    )
    # CHECK-LABEL: TEST: testTileToForallLoopsAndTileOpTypes
    # CHECK: = transform.structured.tile_using_forall
    # CHECK-SAME: num_threads [2, 3, 4]
    # CHECK-SAME: (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.matmul">)
526
527
@run
@create_sequence
def testTileToForallTileSizes(target):
    """Build tile_using_forall with tile_sizes instead of num_threads."""
    tile_sizes = [2, 3, 4]
    structured.TileUsingForallOp(target, tile_sizes=tile_sizes)
    # CHECK-LABEL: TEST: testTileToForallTileSizes
    # CHECK: = transform.structured.tile_using_forall
    # CHECK-SAME: tile_sizes [2, 3, 4]
535
536
@run
@create_sequence
def testTileToForallMixedDynamic(target):
    """Mix a handle-valued thread count with static ones in tile_using_forall."""
    n = structured.MatchOp.match_op_names(target, ["test.dummy"])
    structured.TileUsingForallOp(target, num_threads=[n, 3, 4])
    # Bind the match result so that the dynamic entry is verified to be that
    # exact handle, not just any SSA value (same idiom as testVectorizeMixed).
    # CHECK-LABEL: TEST: testTileToForallMixedDynamic
    # CHECK: %[[N:.*]] = transform.structured.match
    # CHECK: = transform.structured.tile_using_forall
    # CHECK-SAME: num_threads [%[[N]], 3, 4] : (!transform.any_op, !transform.any_op)
545
546
@run
@create_sequence
def testTileToForallPackedDynamic(target):
    """Pass a single handle as packed thread counts to tile_using_forall."""
    n = structured.MatchOp.match_op_names(target, ["test.dummy"])
    structured.TileUsingForallOp(target, num_threads=n)
    # Capture the SSA name of the match result instead of hard-coding "%0",
    # which would break whenever value numbering in the printed IR changes.
    # CHECK-LABEL: TEST: testTileToForallPackedDynamic
    # CHECK: %[[N:.*]] = transform.structured.match
    # CHECK: = transform.structured.tile_using_forall
    # CHECK-SAME: num_threads *(%[[N]]) : (!transform.any_op, !transform.any_op)
555
556
@run
@create_sequence
def testTileToForallMapping(target):
    """Attach a GPU thread mapping attribute to tile_using_forall."""
    thread_mapping = Attribute.parse("[ #gpu.thread<y>, #gpu.thread<x> ]")
    structured.TileUsingForallOp(
        target, num_threads=[2, 3], mapping=thread_mapping
    )
    # CHECK-LABEL: TEST: testTileToForallMapping
    # CHECK: = transform.structured.tile_using_forall
    # CHECK-SAME: mapping = [#gpu.thread<y>, #gpu.thread<x>]
565
566
@run
@create_sequence
def testVectorizeChildrenAndApplyPatternsAllAttrs(target):
    """Every boolean option set to True must print as a unit attribute."""
    structured.VectorizeChildrenAndApplyPatternsOp(
        target,
        disable_multi_reduction_to_contract_patterns=True,
        disable_transfer_permutation_map_lowering_patterns=True,
        vectorize_nd_extract=True,
        vectorize_padding=True,
    )
    # Match the full mnemonic: a bare "vectorize" prefix is shared with the
    # distinct structured.vectorize op exercised by other tests.
    # CHECK-LABEL: TEST: testVectorizeChildrenAndApplyPatternsAllAttrs
    # CHECK: transform.sequence
    # CHECK: = transform.structured.vectorize_children_and_apply_patterns
    # CHECK-SAME: disable_multi_reduction_to_contract_patterns
    # CHECK-SAME: disable_transfer_permutation_map_lowering_patterns
    # CHECK-SAME: vectorize_nd_extract
    # CHECK-SAME: vectorize_padding
584
585
@run
@create_sequence
def testVectorizeChildrenAndApplyPatternsNoAttrs(target):
    """Every boolean option set to False must leave the attribute unprinted."""
    structured.VectorizeChildrenAndApplyPatternsOp(
        target,
        disable_multi_reduction_to_contract_patterns=False,
        disable_transfer_permutation_map_lowering_patterns=False,
        vectorize_nd_extract=False,
        vectorize_padding=False,
    )
    # Match the full mnemonic: a bare "vectorize" prefix is shared with the
    # distinct structured.vectorize op exercised by other tests.
    # CHECK-LABEL: TEST: testVectorizeChildrenAndApplyPatternsNoAttrs
    # CHECK: transform.sequence
    # CHECK: = transform.structured.vectorize_children_and_apply_patterns
    # CHECK-NOT: disable_multi_reduction_to_contract_patterns
    # CHECK-NOT: disable_transfer_permutation_map_lowering_patterns
    # CHECK-NOT: vectorize_nd_extract
    # CHECK-NOT: vectorize_padding
603
604
@run
@create_sequence
def testMatchInterfaceEnum(target):
    """Build structured.match via the base builder with an enum interface.

    ``structured.MatchOp.__base__`` is the class ``MatchOp`` derives from
    (presumably the generated op class beneath the ``match_op_names``
    convenience layer -- confirm). It is called with the result type, the
    target, and raw ``ops`` / ``interface`` values.
    """
    names = ArrayAttr.get([StringAttr.get("test.dummy")])
    result_type = transform.AnyOpType.get()
    # Only the op's creation in the sequence matters; the returned object is
    # not used, so it is not bound to a name.
    structured.MatchOp.__base__(
        result_type,
        target,
        ops=names,
        interface=structured.MatchInterfaceEnum.LinalgOp,
    )
    # CHECK-LABEL: TEST: testMatchInterfaceEnum
    # CHECK: transform.sequence
    # CHECK: = transform.structured.match
    # CHECK: interface{LinalgOp}
620
621
@run
@create_sequence
def testMatchInterfaceEnumReplaceAttributeBuilder(target):
    """Replace the MatchInterfaceEnum attribute builder and pass a string.

    NOTE(review): the replacement builder is registered with ``replace=True``
    and never restored here, so it presumably stays active for any code that
    runs afterwards in the same process.
    """
    # Integer encodings accepted by the replacement builder below.
    # NOTE(review): these mirror the values previously hard-coded in this
    # test; confirm they still match the MatchInterfaceEnum definition.
    encodings = {"LinalgOp": 0, "TilingInterface": 1}

    @register_attribute_builder("MatchInterfaceEnum", replace=True)
    def match_interface_enum(x, context):
        # Fail clearly for names outside the table instead of reaching the
        # return with an unbound value (an opaque UnboundLocalError).
        try:
            y = encodings[x]
        except KeyError:
            raise ValueError(
                f"unsupported MatchInterfaceEnum value: {x!r}"
            ) from None
        return IntegerAttr.get(IntegerType.get_signless(32, context=context), y)

    names = ArrayAttr.get([StringAttr.get("test.dummy")])
    result_type = transform.AnyOpType.get()
    # The string value is converted by the builder registered above.
    structured.MatchOp.__base__(
        result_type,
        target,
        ops=names,
        interface="TilingInterface",
    )
    # CHECK-LABEL: TEST: testMatchInterfaceEnumReplaceAttributeBuilder
    # CHECK: transform.sequence
    # CHECK: = transform.structured.match
    # CHECK: interface{TilingInterface}
645