xref: /llvm-project/mlir/test/python/dialects/transform_vector_ext.py (revision f54dc7b3936f1bd751db710cfc2fec1652159a3f)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import transform
5from mlir.dialects.transform import vector
6
7
def run_apply_patterns(f):
    """Populate a transform.apply_patterns region with ``f`` and print the IR.

    Intended as a decorator. ``f`` is invoked exactly once, at decoration
    time, with the insertion point placed inside the ``patterns`` region of a
    ``transform.ApplyPatternsOp``. That op is created in the body of a
    ``transform.SequenceOp`` (failure mode ``Propagate``) and receives the
    sequence's ``bodyTarget`` as its operand. After ``f`` returns, a
    ``transform.YieldOp`` is appended to the sequence body, then a
    ``TEST: <function name>`` banner and the whole module are printed for
    FileCheck to match.

    Returns ``f`` unchanged, so the decorated name still refers to it.
    """
    with Context(), Location.unknown():
        top_module = Module.create()
        with InsertionPoint(top_module.body):
            seq_op = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with InsertionPoint(seq_op.body):
                patterns_holder = transform.ApplyPatternsOp(seq_op.bodyTarget)
                # Every op that f() creates lands in this patterns region.
                with InsertionPoint(patterns_holder.patterns):
                    f()
                # Close the sequence body after the apply_patterns op.
                transform.YieldOp()
        print("\nTEST:", f.__name__)
        print(top_module)
    return f
25
26
@run_apply_patterns
def non_configurable_patterns():
    """Emit each vector pattern-descriptor op that takes no configuration.

    Every constructor in the table below is called with no arguments, in table
    order; the FileCheck directive written above each entry names the op it is
    expected to print as.
    """
    # CHECK-LABEL: TEST: non_configurable_patterns
    # CHECK: apply_patterns
    argless_pattern_ops = (
        # CHECK: transform.apply_patterns.vector.cast_away_vector_leading_one_dim
        vector.ApplyCastAwayVectorLeadingOneDimPatternsOp,
        # CHECK: transform.apply_patterns.vector.rank_reducing_subview_patterns
        vector.ApplyRankReducingSubviewPatternsOp,
        # CHECK: transform.apply_patterns.vector.transfer_permutation_patterns
        vector.ApplyTransferPermutationPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_broadcast
        vector.ApplyLowerBroadcastPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_masks
        vector.ApplyLowerMasksPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_masked_transfers
        vector.ApplyLowerMaskedTransfersPatternsOp,
        # CHECK: transform.apply_patterns.vector.materialize_masks
        vector.ApplyMaterializeMasksPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_outerproduct
        vector.ApplyLowerOuterProductPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_gather
        vector.ApplyLowerGatherPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_scan
        vector.ApplyLowerScanPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_shape_cast
        vector.ApplyLowerShapeCastPatternsOp,
    )
    for build_op in argless_pattern_ops:
        build_op()
53
54
@run_apply_patterns
def configurable_patterns():
    """Emit vector pattern-descriptor ops that take plain keyword options.

    Each table entry pairs an op class with the keyword arguments it is
    constructed with; entries are built in order.
    """
    # CHECK-LABEL: TEST: configurable_patterns
    # CHECK: apply_patterns
    configured_ops = [
        # CHECK: transform.apply_patterns.vector.lower_transfer
        # CHECK-SAME: max_transfer_rank = 4
        (vector.ApplyLowerTransferPatternsOp, {"max_transfer_rank": 4}),
        # CHECK: transform.apply_patterns.vector.transfer_to_scf
        # CHECK-SAME: max_transfer_rank = 3
        # CHECK-SAME: full_unroll = true
        (
            vector.ApplyTransferToScfPatternsOp,
            {"max_transfer_rank": 3, "full_unroll": True},
        ),
    ]
    for op_cls, options in configured_ops:
        op_cls(**options)
66
67
@run_apply_patterns
def enum_configurable_patterns():
    """Emit vector pattern-descriptor ops configured through enum attributes.

    Each op is built once with no arguments and then once per explicit enum
    value. When the chosen value is the op's default, the attribute is not
    printed, so only the op name is matched for that case.
    """
    # Label this test like its siblings so its directives are confined to the
    # module printed for this function (and earlier tests' directives cannot
    # run into this output).
    # CHECK-LABEL: TEST: enum_configurable_patterns
    # CHECK: apply_patterns

    # --- Contraction lowering ---
    # CHECK: transform.apply_patterns.vector.lower_contraction
    vector.ApplyLowerContractionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = matmulintrinsics
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.Matmul
    )
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = parallelarith
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.ParallelArith
    )

    # --- Multi-reduction lowering ---
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    vector.ApplyLowerMultiReductionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # This is the default mode, not printed.
    vector.ApplyLowerMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerParallel
    )
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # CHECK-SAME: lowering_strategy = innerreduction
    vector.ApplyLowerMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
    )

    # --- Transpose lowering ---
    # CHECK: transform.apply_patterns.vector.lower_transpose
    vector.ApplyLowerTransposePatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # This is the default strategy, not printed.
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.EltWise
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = flat_transpose
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Flat
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_1d
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle1D
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_16x16
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle16x16
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = flat_transpose
    # CHECK-SAME: avx2_lowering_strategy = true
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Flat,
        avx2_lowering_strategy=True,
    )

    # --- Full/partial transfer splitting ---
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    vector.ApplySplitTransferFullPartialPatternsOp()
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = none
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.None_
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "vector-transfer"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.VectorTransfer
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # This is the default mode, not printed.
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.LinalgCopy
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "force-in-bounds"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.ForceInBounds
    )
148