# xref: /llvm-project/mlir/test/python/dialects/transform_vector_ext.py (revision f54dc7b3936f1bd751db710cfc2fec1652159a3f)
# RUN: %PYTHON %s | FileCheck %s
2d5171173SAlex Zinenko
3d5171173SAlex Zinenkofrom mlir.ir import *
4d5171173SAlex Zinenkofrom mlir.dialects import transform
5d5171173SAlex Zinenkofrom mlir.dialects.transform import vector
6d5171173SAlex Zinenko
7d5171173SAlex Zinenko
def run_apply_patterns(f):
    """Decorator that runs ``f`` inside a transform ``apply_patterns`` region.

    A fresh module is created holding a ``transform.SequenceOp`` (failure
    propagation mode ``Propagate``, root handle typed ``!transform.any_op``).
    Inside it, an ``ApplyPatternsOp`` targets the sequence's body argument.
    ``f`` is called with the insertion point set to that op's ``patterns``
    region, so any pattern ops it constructs land there. The sequence is then
    terminated with a ``YieldOp``. Finally the test name and the whole module
    are printed for FileCheck to match.

    Args:
        f: zero-argument callable that builds pattern ops.

    Returns:
        ``f`` itself, so the decorated name still refers to the function.
    """
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                # NOTE(review): the empty list is passed positionally to
                # SequenceOp; presumably the sequence's result types (none).
                seq = transform.SequenceOp(
                    transform.FailurePropagationMode.Propagate,
                    [],
                    transform.AnyOpType.get(),
                )
                with InsertionPoint(seq.body):
                    apply_op = transform.ApplyPatternsOp(seq.bodyTarget)
                    with InsertionPoint(apply_op.patterns):
                        f()
                    # The sequence body must end with a terminator.
                    transform.YieldOp()
            # Print while the context is still active.
            print("\nTEST:", f.__name__)
            print(mod)
    return f
25d5171173SAlex Zinenko
26d5171173SAlex Zinenko
@run_apply_patterns
def non_configurable_patterns():
    """Build every vector pattern op in this test that takes no arguments.

    The constructors are listed in the order their printed forms are expected
    in the output. Each is invoked once with no arguments at the current
    insertion point (the ``patterns`` region set up by the decorator).
    """
    # CHECK-LABEL: TEST: non_configurable_patterns
    # CHECK: apply_patterns
    pattern_builders = (
        # CHECK: transform.apply_patterns.vector.cast_away_vector_leading_one_dim
        vector.ApplyCastAwayVectorLeadingOneDimPatternsOp,
        # CHECK: transform.apply_patterns.vector.rank_reducing_subview_patterns
        vector.ApplyRankReducingSubviewPatternsOp,
        # CHECK: transform.apply_patterns.vector.transfer_permutation_patterns
        vector.ApplyTransferPermutationPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_broadcast
        vector.ApplyLowerBroadcastPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_masks
        vector.ApplyLowerMasksPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_masked_transfers
        vector.ApplyLowerMaskedTransfersPatternsOp,
        # CHECK: transform.apply_patterns.vector.materialize_masks
        vector.ApplyMaterializeMasksPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_outerproduct
        vector.ApplyLowerOuterProductPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_gather
        vector.ApplyLowerGatherPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_scan
        vector.ApplyLowerScanPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_shape_cast
        vector.ApplyLowerShapeCastPatternsOp,
    )
    for build_op in pattern_builders:
        build_op()
53d5171173SAlex Zinenko
54d5171173SAlex Zinenko
@run_apply_patterns
def configurable_patterns():
    """Build vector pattern ops configured with integer and boolean options.

    Each op is created at the current insertion point (the ``patterns``
    region set up by the decorator). The directives below pin the printed
    value of every option passed.
    """
    # CHECK-LABEL: TEST: configurable_patterns
    # CHECK: apply_patterns

    # CHECK: transform.apply_patterns.vector.lower_transfer
    # CHECK-SAME: max_transfer_rank = 4
    lower_transfer_opts = {"max_transfer_rank": 4}
    vector.ApplyLowerTransferPatternsOp(**lower_transfer_opts)

    # CHECK: transform.apply_patterns.vector.transfer_to_scf
    # CHECK-SAME: max_transfer_rank = 3
    # CHECK-SAME: full_unroll = true
    transfer_to_scf_opts = {"max_transfer_rank": 3, "full_unroll": True}
    vector.ApplyTransferToScfPatternsOp(**transfer_to_scf_opts)
66d5171173SAlex Zinenko
67d5171173SAlex Zinenko
@run_apply_patterns
def enum_configurable_patterns():
    """Build vector pattern ops whose configuration is an enum value.

    Each of the four ops is built once with no arguments and then once per
    enum value exercised below. The directives pin the printed spelling of
    each value. Where a comment says a value is the default, a negative
    directive checks that the strategy attribute is elided from that op's
    printed line.
    """
    # Label this test like its siblings so FileCheck partitions the output
    # per test and reports mismatches against the right section.
    # CHECK-LABEL: TEST: enum_configurable_patterns
    # CHECK: apply_patterns

    # CHECK: transform.apply_patterns.vector.lower_contraction
    vector.ApplyLowerContractionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = matmulintrinsics
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.Matmul
    )
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = parallelarith
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.ParallelArith
    )

    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    vector.ApplyLowerMultiReductionPatternsOp()
    # This is the default mode, not printed: the negative directive scans from
    # the end of this match up to the next op's name.
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # CHECK-NOT: lowering_strategy
    vector.ApplyLowerMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerParallel
    )
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # CHECK-SAME: lowering_strategy = innerreduction
    vector.ApplyLowerMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
    )

    # CHECK: transform.apply_patterns.vector.lower_transpose
    vector.ApplyLowerTransposePatternsOp()
    # This is the default mode, not printed.
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-NOT: lowering_strategy
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.EltWise
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = flat_transpose
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Flat
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_1d
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle1D
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_16x16
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle16x16
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = flat_transpose
    # CHECK-SAME: avx2_lowering_strategy = true
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Flat,
        avx2_lowering_strategy=True,
    )

    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    vector.ApplySplitTransferFullPartialPatternsOp()
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = none
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.None_
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "vector-transfer"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.VectorTransfer
    )
    # This is the default mode, not printed.
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-NOT: split_transfer_strategy
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.LinalgCopy
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "force-in-bounds"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.ForceInBounds
    )
148