xref: /llvm-project/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir (revision 8a5f33fd12621c8ac0def0481700246a34f4f674)
1// RUN: mlir-opt %s -arm-sme-outer-product-fusion -cse -split-input-file | FileCheck %s
2
/// Two masked f16 -> f32 outer products that accumulate into the same tile are
/// expected to fuse into one 2-way widening fmopa whose operands and masks are
/// the pairwise interleavings of the original operands and masks.
// CHECK-LABEL: @outerproduct_add_widening_2way_f16f16f32
// CHECK-SAME:    %[[A0:[a-z0-9]+]]: vector<[4]xf16>, %[[B0:[a-z0-9]+]]: vector<[4]xf16>,
// CHECK-SAME:    %[[A1:[a-z0-9]+]]: vector<[4]xf16>, %[[B1:[a-z0-9]+]]: vector<[4]xf16>,
// CHECK-SAME:    %[[A0_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B0_MASK:[a-z0-9]+]]: vector<[4]xi1>,
// CHECK-SAME:    %[[A1_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B1_MASK:[a-z0-9]+]]: vector<[4]xi1>
// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<[4]x[4]xf32>
// CHECK-DAG: %[[LHS:.*]] = vector.interleave %[[A0]], %[[A1]] : vector<[4]xf16> -> vector<[8]xf16>
// CHECK-DAG: %[[RHS:.*]] = vector.interleave %[[B0]], %[[B1]] : vector<[4]xf16> -> vector<[8]xf16>
// CHECK-DAG: %[[LHS_MASK:.*]] = vector.interleave %[[A0_MASK]], %[[A1_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
// CHECK-DAG: %[[RHS_MASK:.*]] = vector.interleave %[[B0_MASK]], %[[B1_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
// CHECK-DAG: arm_sme.fmopa_2way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
func.func @outerproduct_add_widening_2way_f16f16f32(
    %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
    %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>,
    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> {
  %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
  %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
  %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
  %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>

  %acc = arith.constant dense<0.0> : vector<[4]x[4]xf32>

  // The second outer product accumulates into the first one's result.
  %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xf32>, vector<[4]xf32>
  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32>

  return %1 : vector<[4]x[4]xf32>
}
29
30// -----
31
/// Verify that a chain of four outer products is fused into two 2-way
/// widening outer products.
34
// CHECK-LABEL: @outerproduct_x2_add_widening_2way_f16f16f32
// CHECK-COUNT-2: arm_sme.fmopa_2way
func.func @outerproduct_x2_add_widening_2way_f16f16f32(
    %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
    %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>,
    %a2 : vector<[4]xf16>, %b2 : vector<[4]xf16>,
    %a3 : vector<[4]xf16>, %b3 : vector<[4]xf16>) -> vector<[4]x[4]xf32> {
  // Widen all left-hand operands to f32.
  %lhs0 = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
  %lhs1 = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
  %lhs2 = arith.extf %a2 : vector<[4]xf16> to vector<[4]xf32>
  %lhs3 = arith.extf %a3 : vector<[4]xf16> to vector<[4]xf32>

  // Widen all right-hand operands to f32.
  %rhs0 = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
  %rhs1 = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
  %rhs2 = arith.extf %b2 : vector<[4]xf16> to vector<[4]xf32>
  %rhs3 = arith.extf %b3 : vector<[4]xf16> to vector<[4]xf32>

  // Unmasked chain; the first outer product has no explicit accumulator.
  %tile0 = arm_sme.outerproduct %lhs0, %rhs0 : vector<[4]xf32>, vector<[4]xf32>
  %tile1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%tile0) : vector<[4]xf32>, vector<[4]xf32>
  %tile2 = arm_sme.outerproduct %lhs2, %rhs2 acc(%tile1) : vector<[4]xf32>, vector<[4]xf32>
  %tile3 = arm_sme.outerproduct %lhs3, %rhs3 acc(%tile2) : vector<[4]xf32>, vector<[4]xf32>

  return %tile3 : vector<[4]x[4]xf32>
}
61
62// -----
63
/// Two masked, subtracting f16 -> f32 outer products sharing one accumulator
/// chain are expected to fuse into a single 2-way widening fmops.
// CHECK-LABEL: @outerproduct_sub_widening_2way_f16f16f32
// CHECK: arm_sme.fmops_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
func.func @outerproduct_sub_widening_2way_f16f16f32(
    %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
    %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>,
    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> {
  %zero = arith.constant dense<0.0> : vector<[4]x[4]xf32>

  // Left-hand operands, widened to f32.
  %lhs0 = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
  %lhs1 = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
  // Right-hand operands, widened to f32.
  %rhs0 = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
  %rhs1 = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>

  %partial = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xf32>, vector<[4]xf32>
  %result = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%partial) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32>
  return %result : vector<[4]x[4]xf32>
}
83
84// -----
85
/// Two masked bf16 -> f32 outer products accumulating into one tile are
/// expected to fuse into a single 2-way widening fmopa on bf16 operands.
// CHECK-LABEL: @outerproduct_add_widening_2way_bf16bf16f32
// CHECK: arm_sme.fmopa_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
func.func @outerproduct_add_widening_2way_bf16bf16f32(
    %a0 : vector<[4]xbf16>, %b0 : vector<[4]xbf16>,
    %a1 : vector<[4]xbf16>, %b1 : vector<[4]xbf16>,
    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> {
  %zero = arith.constant dense<0.0> : vector<[4]x[4]xf32>

  // Left-hand operands, widened to f32.
  %lhs0 = arith.extf %a0 : vector<[4]xbf16> to vector<[4]xf32>
  %lhs1 = arith.extf %a1 : vector<[4]xbf16> to vector<[4]xf32>
  // Right-hand operands, widened to f32.
  %rhs0 = arith.extf %b0 : vector<[4]xbf16> to vector<[4]xf32>
  %rhs1 = arith.extf %b1 : vector<[4]xbf16> to vector<[4]xf32>

  %partial = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xf32>, vector<[4]xf32>
  %result = arm_sme.outerproduct %lhs1, %rhs1 acc(%partial) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32>
  return %result : vector<[4]x[4]xf32>
}
105
106// -----
107
/// Two masked, subtracting bf16 -> f32 outer products sharing one accumulator
/// chain are expected to fuse into a single 2-way widening fmops.
// CHECK-LABEL: @outerproduct_sub_widening_2way_bf16bf16f32
// CHECK: arm_sme.fmops_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
func.func @outerproduct_sub_widening_2way_bf16bf16f32(
    %a0 : vector<[4]xbf16>, %b0 : vector<[4]xbf16>,
    %a1 : vector<[4]xbf16>, %b1 : vector<[4]xbf16>,
    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> {
  %zero = arith.constant dense<0.0> : vector<[4]x[4]xf32>

  // Left-hand operands, widened to f32.
  %lhs0 = arith.extf %a0 : vector<[4]xbf16> to vector<[4]xf32>
  %lhs1 = arith.extf %a1 : vector<[4]xbf16> to vector<[4]xf32>
  // Right-hand operands, widened to f32.
  %rhs0 = arith.extf %b0 : vector<[4]xbf16> to vector<[4]xf32>
  %rhs1 = arith.extf %b1 : vector<[4]xbf16> to vector<[4]xf32>

  %partial = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xf32>, vector<[4]xf32>
  %result = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%partial) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32>
  return %result : vector<[4]x[4]xf32>
}
127
128// -----
129
/// Two masked outer products over sign-extended i16 -> i32 operands are
/// expected to fuse into a single 2-way widening smopa.
// CHECK-LABEL: @outerproduct_add_widening_2way_signed_i16i16i32
// CHECK: arm_sme.smopa_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
func.func @outerproduct_add_widening_2way_signed_i16i16i32(
    %a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>,
    %a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>,
    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
  %zero = arith.constant dense<0> : vector<[4]x[4]xi32>

  // Left-hand operands, sign-extended to i32.
  %lhs0 = arith.extsi %a0 : vector<[4]xi16> to vector<[4]xi32>
  %lhs1 = arith.extsi %a1 : vector<[4]xi16> to vector<[4]xi32>
  // Right-hand operands, sign-extended to i32.
  %rhs0 = arith.extsi %b0 : vector<[4]xi16> to vector<[4]xi32>
  %rhs1 = arith.extsi %b1 : vector<[4]xi16> to vector<[4]xi32>

  %partial = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
  %result = arm_sme.outerproduct %lhs1, %rhs1 acc(%partial) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
  return %result : vector<[4]x[4]xi32>
}
149
150// -----
151
/// Two masked, subtracting outer products over sign-extended i16 -> i32
/// operands are expected to fuse into a single 2-way widening smops.
// CHECK-LABEL: @outerproduct_sub_widening_2way_signed_i16i16i32
// CHECK: arm_sme.smops_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
func.func @outerproduct_sub_widening_2way_signed_i16i16i32(
    %a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>,
    %a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>,
    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
  %zero = arith.constant dense<0> : vector<[4]x[4]xi32>

  // Left-hand operands, sign-extended to i32.
  %lhs0 = arith.extsi %a0 : vector<[4]xi16> to vector<[4]xi32>
  %lhs1 = arith.extsi %a1 : vector<[4]xi16> to vector<[4]xi32>
  // Right-hand operands, sign-extended to i32.
  %rhs0 = arith.extsi %b0 : vector<[4]xi16> to vector<[4]xi32>
  %rhs1 = arith.extsi %b1 : vector<[4]xi16> to vector<[4]xi32>

  %partial = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
  %result = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%partial) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
  return %result : vector<[4]x[4]xi32>
}
171
172// -----
173
/// Two masked outer products over zero-extended i16 -> i32 operands are
/// expected to fuse into a single 2-way widening umopa.
// CHECK-LABEL: @outerproduct_add_widening_2way_unsigned_i16i16i32
// CHECK: arm_sme.umopa_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
func.func @outerproduct_add_widening_2way_unsigned_i16i16i32(
    %a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>,
    %a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>,
    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
  %zero = arith.constant dense<0> : vector<[4]x[4]xi32>

  // Left-hand operands, zero-extended to i32.
  %lhs0 = arith.extui %a0 : vector<[4]xi16> to vector<[4]xi32>
  %lhs1 = arith.extui %a1 : vector<[4]xi16> to vector<[4]xi32>
  // Right-hand operands, zero-extended to i32.
  %rhs0 = arith.extui %b0 : vector<[4]xi16> to vector<[4]xi32>
  %rhs1 = arith.extui %b1 : vector<[4]xi16> to vector<[4]xi32>

  %partial = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
  %result = arm_sme.outerproduct %lhs1, %rhs1 acc(%partial) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
  return %result : vector<[4]x[4]xi32>
}
193
194// -----
195
/// Two masked, subtracting outer products over zero-extended i16 -> i32
/// operands are expected to fuse into a single 2-way widening umops.
// CHECK-LABEL: @outerproduct_sub_widening_2way_unsigned_i16i16i32
// CHECK: arm_sme.umops_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
func.func @outerproduct_sub_widening_2way_unsigned_i16i16i32(
    %a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>,
    %a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>,
    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
  %zero = arith.constant dense<0> : vector<[4]x[4]xi32>

  // Left-hand operands, zero-extended to i32.
  %lhs0 = arith.extui %a0 : vector<[4]xi16> to vector<[4]xi32>
  %lhs1 = arith.extui %a1 : vector<[4]xi16> to vector<[4]xi32>
  // Right-hand operands, zero-extended to i32.
  %rhs0 = arith.extui %b0 : vector<[4]xi16> to vector<[4]xi32>
  %rhs1 = arith.extui %b1 : vector<[4]xi16> to vector<[4]xi32>

  %partial = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
  %result = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%partial) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
  return %result : vector<[4]x[4]xi32>
}
215
216// -----
217
// CHECK-LABEL: @outerproduct_add_widening_4way_signed_i8i8i32
// CHECK-SAME:    %[[A0:[a-z0-9]+]]: vector<[4]xi8>, %[[B0:[a-z0-9]+]]: vector<[4]xi8>,
// CHECK-SAME:    %[[A1:[a-z0-9]+]]: vector<[4]xi8>, %[[B1:[a-z0-9]+]]: vector<[4]xi8>,
// CHECK-SAME:    %[[A2:[a-z0-9]+]]: vector<[4]xi8>, %[[B2:[a-z0-9]+]]: vector<[4]xi8>,
// CHECK-SAME:    %[[A3:[a-z0-9]+]]: vector<[4]xi8>, %[[B3:[a-z0-9]+]]: vector<[4]xi8>,
// CHECK-SAME:    %[[A0_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B0_MASK:[a-z0-9]+]]: vector<[4]xi1>,
// CHECK-SAME:    %[[A1_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B1_MASK:[a-z0-9]+]]: vector<[4]xi1>,
// CHECK-SAME:    %[[A2_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B2_MASK:[a-z0-9]+]]: vector<[4]xi1>,
// CHECK-SAME:    %[[A3_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B3_MASK:[a-z0-9]+]]: vector<[4]xi1>
// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0> : vector<[4]x[4]xi32>
// CHECK-DAG: %[[LHS0:.*]] = vector.interleave %[[A0]], %[[A2]] : vector<[4]xi8> -> vector<[8]xi8>
// CHECK-DAG: %[[LHS1:.*]] = vector.interleave %[[A1]], %[[A3]] : vector<[4]xi8> -> vector<[8]xi8>
// CHECK-DAG: %[[RHS0:.*]] = vector.interleave %[[B0]], %[[B2]] : vector<[4]xi8> -> vector<[8]xi8>
// CHECK-DAG: %[[RHS1:.*]] = vector.interleave %[[B1]], %[[B3]] : vector<[4]xi8> -> vector<[8]xi8>
// CHECK-DAG: %[[LHS:.*]] = vector.interleave %[[LHS0]], %[[LHS1]] : vector<[8]xi8> -> vector<[16]xi8>
// CHECK-DAG: %[[RHS:.*]] = vector.interleave %[[RHS0]], %[[RHS1]] : vector<[8]xi8> -> vector<[16]xi8>
// CHECK-DAG: %[[LHS0_MASK:.*]] = vector.interleave %[[A0_MASK]], %[[A2_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
// CHECK-DAG: %[[LHS1_MASK:.*]] = vector.interleave %[[A1_MASK]], %[[A3_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
// CHECK-DAG: %[[RHS0_MASK:.*]] = vector.interleave %[[B0_MASK]], %[[B2_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
// CHECK-DAG: %[[RHS1_MASK:.*]] = vector.interleave %[[B1_MASK]], %[[B3_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
// CHECK-DAG: %[[LHS_MASK:.*]] = vector.interleave %[[LHS0_MASK]], %[[LHS1_MASK]] : vector<[8]xi1> -> vector<[16]xi1>
// CHECK-DAG: %[[RHS_MASK:.*]] = vector.interleave %[[RHS0_MASK]], %[[RHS1_MASK]] : vector<[8]xi1> -> vector<[16]xi1>
// CHECK-DAG: arm_sme.smopa_4way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
func.func @outerproduct_add_widening_4way_signed_i8i8i32(
    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
    %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
    %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
  %zero = arith.constant dense<0> : vector<[4]x[4]xi32>

  // Left-hand operands, sign-extended i8 -> i32.
  %lhs0 = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
  %lhs1 = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
  %lhs2 = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
  %lhs3 = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>

  // Right-hand operands, sign-extended i8 -> i32.
  %rhs0 = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
  %rhs1 = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
  %rhs2 = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
  %rhs3 = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>

  // Four masked outer products, each accumulating into the previous result.
  %tile0 = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
  %tile1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%tile0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
  %tile2 = arm_sme.outerproduct %lhs2, %rhs2 acc(%tile1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
  %tile3 = arm_sme.outerproduct %lhs3, %rhs3 acc(%tile2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>

  return %tile3 : vector<[4]x[4]xi32>
}
271
272// -----
273
/// Four masked, subtracting outer products over sign-extended i8 -> i32
/// operands are expected to fuse into a single 4-way widening smops.
// CHECK-LABEL: @outerproduct_sub_widening_4way_signed_i8i8i32
// CHECK: arm_sme.smops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
func.func @outerproduct_sub_widening_4way_signed_i8i8i32(
    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
    %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
    %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
  %zero = arith.constant dense<0> : vector<[4]x[4]xi32>

  // Left-hand operands, sign-extended i8 -> i32.
  %lhs0 = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
  %lhs1 = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
  %lhs2 = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
  %lhs3 = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>

  // Right-hand operands, sign-extended i8 -> i32.
  %rhs0 = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
  %rhs1 = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
  %rhs2 = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
  %rhs3 = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>

  %tile0 = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
  %tile1 = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%tile0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
  %tile2 = arm_sme.outerproduct %lhs2, %rhs2 kind<sub> acc(%tile1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
  %tile3 = arm_sme.outerproduct %lhs3, %rhs3 kind<sub> acc(%tile2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>

  return %tile3 : vector<[4]x[4]xi32>
}
306
307// -----
308
/// Four masked outer products over sign-extended i16 -> i64 operands are
/// expected to fuse into a single 4-way widening smopa.
// CHECK-LABEL: @outerproduct_add_widening_4way_signed_i16i16i64
// CHECK: arm_sme.smopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
func.func @outerproduct_add_widening_4way_signed_i16i16i64(
    %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
    %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
    %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
    %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
    %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
    %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
    %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
    %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
  %zero = arith.constant dense<0> : vector<[2]x[2]xi64>

  // Left-hand operands, sign-extended i16 -> i64.
  %lhs0 = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
  %lhs1 = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
  %lhs2 = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
  %lhs3 = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>

  // Right-hand operands, sign-extended i16 -> i64.
  %rhs0 = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64>
  %rhs1 = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64>
  %rhs2 = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64>
  %rhs3 = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64>

  %tile0 = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
  %tile1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%tile0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
  %tile2 = arm_sme.outerproduct %lhs2, %rhs2 acc(%tile1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
  %tile3 = arm_sme.outerproduct %lhs3, %rhs3 acc(%tile2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>

  return %tile3 : vector<[2]x[2]xi64>
}
341
342// -----
343
/// Four masked, subtracting outer products over sign-extended i16 -> i64
/// operands are expected to fuse into a single 4-way widening smops.
// CHECK-LABEL: @outerproduct_sub_widening_4way_signed_i16i16i64
// CHECK: arm_sme.smops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
func.func @outerproduct_sub_widening_4way_signed_i16i16i64(
    %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
    %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
    %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
    %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
    %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
    %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
    %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
    %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
  %zero = arith.constant dense<0> : vector<[2]x[2]xi64>

  // Left-hand operands, sign-extended i16 -> i64.
  %lhs0 = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
  %lhs1 = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
  %lhs2 = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
  %lhs3 = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>

  // Right-hand operands, sign-extended i16 -> i64.
  %rhs0 = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64>
  %rhs1 = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64>
  %rhs2 = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64>
  %rhs3 = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64>

  %tile0 = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
  %tile1 = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%tile0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
  %tile2 = arm_sme.outerproduct %lhs2, %rhs2 kind<sub> acc(%tile1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
  %tile3 = arm_sme.outerproduct %lhs3, %rhs3 kind<sub> acc(%tile2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>

  return %tile3 : vector<[2]x[2]xi64>
}
376
377// -----
378
/// Four masked outer products over zero-extended i8 -> i32 operands are
/// expected to fuse into a single 4-way widening umopa.
// CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_i8i8i32
// CHECK: arm_sme.umopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
func.func @outerproduct_add_widening_4way_unsigned_i8i8i32(
    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
    %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
    %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
  %zero = arith.constant dense<0> : vector<[4]x[4]xi32>

  // Left-hand operands, zero-extended i8 -> i32.
  %lhs0 = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32>
  %lhs1 = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32>
  %lhs2 = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32>
  %lhs3 = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32>

  // Right-hand operands, zero-extended i8 -> i32.
  %rhs0 = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
  %rhs1 = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
  %rhs2 = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
  %rhs3 = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>

  %tile0 = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
  %tile1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%tile0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
  %tile2 = arm_sme.outerproduct %lhs2, %rhs2 acc(%tile1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
  %tile3 = arm_sme.outerproduct %lhs3, %rhs3 acc(%tile2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>

  return %tile3 : vector<[4]x[4]xi32>
}
411
412// -----
413
/// Four masked, subtracting outer products over zero-extended i8 -> i32
/// operands are expected to fuse into a single 4-way widening umops.
// CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_i8i8i32
// CHECK: arm_sme.umops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
func.func @outerproduct_sub_widening_4way_unsigned_i8i8i32(
    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
    %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
    %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
  %zero = arith.constant dense<0> : vector<[4]x[4]xi32>

  // Left-hand operands, zero-extended i8 -> i32.
  %lhs0 = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32>
  %lhs1 = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32>
  %lhs2 = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32>
  %lhs3 = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32>

  // Right-hand operands, zero-extended i8 -> i32.
  %rhs0 = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
  %rhs1 = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
  %rhs2 = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
  %rhs3 = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>

  %tile0 = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
  %tile1 = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%tile0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
  %tile2 = arm_sme.outerproduct %lhs2, %rhs2 kind<sub> acc(%tile1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
  %tile3 = arm_sme.outerproduct %lhs3, %rhs3 kind<sub> acc(%tile2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>

  return %tile3 : vector<[4]x[4]xi32>
}
446
447// -----
448
/// Four masked outer products over zero-extended i16 -> i64 operands are
/// expected to fuse into a single 4-way widening umopa.
// CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_i16i16i64
// CHECK: arm_sme.umopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
func.func @outerproduct_add_widening_4way_unsigned_i16i16i64(
    %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
    %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
    %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
    %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
    %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
    %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
    %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
    %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
  %zero = arith.constant dense<0> : vector<[2]x[2]xi64>

  // Left-hand operands, zero-extended i16 -> i64.
  %lhs0 = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64>
  %lhs1 = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64>
  %lhs2 = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64>
  %lhs3 = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64>

  // Right-hand operands, zero-extended i16 -> i64.
  %rhs0 = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
  %rhs1 = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
  %rhs2 = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
  %rhs3 = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>

  %tile0 = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
  %tile1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%tile0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
  %tile2 = arm_sme.outerproduct %lhs2, %rhs2 acc(%tile1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
  %tile3 = arm_sme.outerproduct %lhs3, %rhs3 acc(%tile2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>

  return %tile3 : vector<[2]x[2]xi64>
}
481
482// -----
483
/// Four masked, subtracting outer products over zero-extended i16 -> i64
/// operands are expected to fuse into a single 4-way widening umops.
// CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_i16i16i64
// CHECK: arm_sme.umops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
func.func @outerproduct_sub_widening_4way_unsigned_i16i16i64(
    %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
    %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
    %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
    %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
    %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
    %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
    %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
    %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
  %zero = arith.constant dense<0> : vector<[2]x[2]xi64>

  // Left-hand operands, zero-extended i16 -> i64.
  %lhs0 = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64>
  %lhs1 = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64>
  %lhs2 = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64>
  %lhs3 = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64>

  // Right-hand operands, zero-extended i16 -> i64.
  %rhs0 = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
  %rhs1 = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
  %rhs2 = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
  %rhs3 = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>

  %tile0 = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
  %tile1 = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%tile0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
  %tile2 = arm_sme.outerproduct %lhs2, %rhs2 kind<sub> acc(%tile1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
  %tile3 = arm_sme.outerproduct %lhs3, %rhs3 kind<sub> acc(%tile2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>

  return %tile3 : vector<[2]x[2]xi64>
}
516
517// -----
518
/// Four masked outer products with sign-extended left-hand and zero-extended
/// right-hand i8 -> i32 operands are expected to fuse into one 4-way sumopa.
// CHECK-LABEL: @outerproduct_add_widening_4way_signed_by_unsigned_i8i8i32
// CHECK: arm_sme.sumopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
func.func @outerproduct_add_widening_4way_signed_by_unsigned_i8i8i32(
    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
    %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
    %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
  %zero = arith.constant dense<0> : vector<[4]x[4]xi32>

  // Left-hand operands are sign-extended...
  %lhs0 = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
  %lhs1 = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
  %lhs2 = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
  %lhs3 = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>

  // ...while right-hand operands are zero-extended.
  %rhs0 = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
  %rhs1 = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
  %rhs2 = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
  %rhs3 = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>

  %tile0 = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
  %tile1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%tile0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
  %tile2 = arm_sme.outerproduct %lhs2, %rhs2 acc(%tile1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
  %tile3 = arm_sme.outerproduct %lhs3, %rhs3 acc(%tile2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>

  return %tile3 : vector<[4]x[4]xi32>
}
551
552// -----
553
/// Four masked, subtracting outer products with sign-extended left-hand and
/// zero-extended right-hand i8 -> i32 operands are expected to fuse into one
/// 4-way sumops.
// CHECK-LABEL: @outerproduct_sub_widening_4way_signed_by_unsigned_i8i8i32
// CHECK: arm_sme.sumops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
func.func @outerproduct_sub_widening_4way_signed_by_unsigned_i8i8i32(
    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
    %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
    %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
  %zero = arith.constant dense<0> : vector<[4]x[4]xi32>

  // Left-hand operands are sign-extended...
  %lhs0 = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
  %lhs1 = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
  %lhs2 = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
  %lhs3 = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>

  // ...while right-hand operands are zero-extended.
  %rhs0 = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
  %rhs1 = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
  %rhs2 = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
  %rhs3 = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>

  %tile0 = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
  %tile1 = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%tile0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
  %tile2 = arm_sme.outerproduct %lhs2, %rhs2 kind<sub> acc(%tile1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
  %tile3 = arm_sme.outerproduct %lhs3, %rhs3 kind<sub> acc(%tile2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>

  return %tile3 : vector<[4]x[4]xi32>
}
586
587// -----
588
// CHECK-LABEL: @outerproduct_add_widening_4way_signed_by_unsigned_i16i16i64
// CHECK: arm_sme.sumopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
/// Chain of four masked (default `add` kind) outer products whose LHS operands
/// are sign-extended and whose RHS operands are zero-extended from i16 to i64;
/// expected to fuse into a single sumopa_4way.
func.func @outerproduct_add_widening_4way_signed_by_unsigned_i16i16i64(
    %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
    %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
    %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
    %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
    %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
    %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
    %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
    %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
  %lhs0 = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
  %rhs0 = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
  %lhs1 = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
  %rhs1 = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
  %lhs2 = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
  %rhs2 = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
  %lhs3 = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>
  %rhs3 = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>

  %zero = arith.constant dense<0> : vector<[2]x[2]xi64>

  %res0 = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
  %res1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%res0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
  %res2 = arm_sme.outerproduct %lhs2, %rhs2 acc(%res1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
  %res3 = arm_sme.outerproduct %lhs3, %rhs3 acc(%res2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
  return %res3 : vector<[2]x[2]xi64>
}
621
622// -----
623
// CHECK-LABEL: @outerproduct_sub_widening_4way_signed_by_unsigned_i16i16i64
// CHECK: arm_sme.sumops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
/// Chain of four masked `kind<sub>` outer products whose LHS operands are
/// sign-extended and whose RHS operands are zero-extended from i16 to i64;
/// expected to fuse into a single sumops_4way.
func.func @outerproduct_sub_widening_4way_signed_by_unsigned_i16i16i64(
    %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
    %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
    %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
    %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
    %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
    %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
    %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
    %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
  %lhs0 = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
  %rhs0 = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
  %lhs1 = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
  %rhs1 = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
  %lhs2 = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
  %rhs2 = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
  %lhs3 = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>
  %rhs3 = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>

  %zero = arith.constant dense<0> : vector<[2]x[2]xi64>

  %res0 = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
  %res1 = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%res0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
  %res2 = arm_sme.outerproduct %lhs2, %rhs2 kind<sub> acc(%res1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
  %res3 = arm_sme.outerproduct %lhs3, %rhs3 kind<sub> acc(%res2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
  return %res3 : vector<[2]x[2]xi64>
}
656
657// -----
658
// CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_by_signed_i8i8i32
// CHECK: arm_sme.usmopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
/// Chain of four masked (default `add` kind) outer products whose LHS operands
/// are zero-extended and whose RHS operands are sign-extended from i8; expected
/// to fuse into a single unsigned-by-signed usmopa_4way.
func.func @outerproduct_add_widening_4way_unsigned_by_signed_i8i8i32(
    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
    %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
    %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
  %lhs0 = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32>
  %rhs0 = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
  %lhs1 = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32>
  %rhs1 = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
  %lhs2 = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32>
  %rhs2 = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
  %lhs3 = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32>
  %rhs3 = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>

  %zero = arith.constant dense<0> : vector<[4]x[4]xi32>

  %res0 = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
  %res1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%res0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
  %res2 = arm_sme.outerproduct %lhs2, %rhs2 acc(%res1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
  %res3 = arm_sme.outerproduct %lhs3, %rhs3 acc(%res2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
  return %res3 : vector<[4]x[4]xi32>
}
691
692// -----
693
// CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_by_signed_i8i8i32
// CHECK: arm_sme.usmops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
/// Chain of four masked `kind<sub>` outer products whose LHS operands are
/// zero-extended and whose RHS operands are sign-extended from i8; expected to
/// fuse into a single usmops_4way.
func.func @outerproduct_sub_widening_4way_unsigned_by_signed_i8i8i32(
    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
    %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
    %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
  %lhs0 = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32>
  %rhs0 = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
  %lhs1 = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32>
  %rhs1 = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
  %lhs2 = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32>
  %rhs2 = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
  %lhs3 = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32>
  %rhs3 = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>

  %zero = arith.constant dense<0> : vector<[4]x[4]xi32>

  %res0 = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
  %res1 = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%res0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
  %res2 = arm_sme.outerproduct %lhs2, %rhs2 kind<sub> acc(%res1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
  %res3 = arm_sme.outerproduct %lhs3, %rhs3 kind<sub> acc(%res2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
  return %res3 : vector<[4]x[4]xi32>
}
726
727// -----
728
// CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_by_signed_i16i16i64
// CHECK: arm_sme.usmopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
/// Chain of four masked (default `add` kind) outer products whose LHS operands
/// are zero-extended and whose RHS operands are sign-extended from i16 to i64;
/// expected to fuse into a single usmopa_4way.
func.func @outerproduct_add_widening_4way_unsigned_by_signed_i16i16i64(
    %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
    %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
    %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
    %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
    %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
    %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
    %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
    %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
  %lhs0 = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64>
  %rhs0 = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64>
  %lhs1 = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64>
  %rhs1 = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64>
  %lhs2 = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64>
  %rhs2 = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64>
  %lhs3 = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64>
  %rhs3 = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64>

  %zero = arith.constant dense<0> : vector<[2]x[2]xi64>

  %res0 = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
  %res1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%res0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
  %res2 = arm_sme.outerproduct %lhs2, %rhs2 acc(%res1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
  %res3 = arm_sme.outerproduct %lhs3, %rhs3 acc(%res2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
  return %res3 : vector<[2]x[2]xi64>
}
761
762// -----
763
// CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_by_signed_i16i16i64
// CHECK: arm_sme.usmops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
/// Chain of four masked `kind<sub>` outer products whose LHS operands are
/// zero-extended and whose RHS operands are sign-extended from i16 to i64;
/// expected to fuse into a single usmops_4way.
func.func @outerproduct_sub_widening_4way_unsigned_by_signed_i16i16i64(
    %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
    %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
    %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
    %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
    %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
    %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
    %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
    %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
  %lhs0 = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64>
  %rhs0 = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64>
  %lhs1 = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64>
  %rhs1 = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64>
  %lhs2 = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64>
  %rhs2 = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64>
  %lhs3 = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64>
  %rhs3 = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64>

  %zero = arith.constant dense<0> : vector<[2]x[2]xi64>

  %res0 = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
  %res1 = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%res0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
  %res2 = arm_sme.outerproduct %lhs2, %rhs2 kind<sub> acc(%res1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
  %res3 = arm_sme.outerproduct %lhs3, %rhs3 kind<sub> acc(%res2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
  return %res3 : vector<[2]x[2]xi64>
}
796
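/// Tests for related patterns: an extract from the result of an arith extension
/// op is rewritten to extract from the narrow source and extend only the
/// extracted value.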
798
799// -----
800
// CHECK-LABEL: @extract_from_arith_ext(
// CHECK-SAME:                          %[[SRC:.*]]: vector<4x[8]xi8>
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[SRC]][0] : vector<[8]xi8> from vector<4x[8]xi8>
// CHECK: %[[EXTEND:.*]] = arith.extsi %[[EXTRACT]] : vector<[8]xi8> to vector<[8]xi32>
// CHECK: return %[[EXTEND]]
/// A constant-index vector.extract of a sign-extended 2-D vector should be
/// rewritten to extract the i8 row first and then sign-extend only that row.
func.func @extract_from_arith_ext(%src: vector<4x[8]xi8>) -> vector<[8]xi32> {
  %wide = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
  %row = vector.extract %wide[0] : vector<[8]xi32> from vector<4x[8]xi32>
  return %row : vector<[8]xi32>
}
811
812// -----
813
// CHECK-LABEL: @non_constant_extract_from_arith_ext(
// CHECK-SAME:                                       %[[SRC:[a-z0-9]+]]: vector<4x[8]xi8>,
// CHECK-SAME:                                       %[[DIM:[a-z0-9]+]]: index
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[SRC]][%[[DIM]]] : vector<[8]xi8> from vector<4x[8]xi8>
// CHECK: %[[EXTEND:.*]] = arith.extui %[[EXTRACT]] : vector<[8]xi8> to vector<[8]xi32>
// CHECK: return %[[EXTEND]]
/// Same rewrite as above, but with a dynamic (SSA index) position and a
/// zero-extension: the extract moves before the extui.
func.func @non_constant_extract_from_arith_ext(%src: vector<4x[8]xi8>, %dim: index) -> vector<[8]xi32> {
  %wide = arith.extui %src : vector<4x[8]xi8> to vector<4x[8]xi32>
  %row = vector.extract %wide[%dim] : vector<[8]xi32> from vector<4x[8]xi32>
  return %row : vector<[8]xi32>
}
825
826// -----
827
// CHECK-LABEL: @scalable_extract_from_arith_ext(
// CHECK-SAME:                                   %[[SRC:.*]]: vector<[8]xf16>
// CHECK: %[[EXTRACT:.*]] = vector.scalable.extract %[[SRC]][0] : vector<[4]xf16> from vector<[8]xf16>
// CHECK: %[[EXTEND:.*]] = arith.extf %[[EXTRACT]] : vector<[4]xf16> to vector<[4]xf32>
// CHECK: return %[[EXTEND]]
/// A vector.scalable.extract of an extf result should be rewritten to extract
/// the f16 sub-vector first and then extend only that part to f32.
func.func @scalable_extract_from_arith_ext(%src: vector<[8]xf16>) -> vector<[4]xf32> {
  %wide = arith.extf %src : vector<[8]xf16> to vector<[8]xf32>
  %part = vector.scalable.extract %wide[0] : vector<[4]xf32> from vector<[8]xf32>
  return %part : vector<[4]xf32>
}
838
/// Negative tests: outer products that must not be fused.
840
841// -----
842
// CHECK-LABEL: @outerproduct_widening_2way__no_acc
// CHECK-NOT: arm_sme.fmopa_2way
// CHECK: arm_sme.outerproduct
// CHECK-NOT: arm_sme.fmopa_2way
/// A lone outer product without an accumulator has no partner to fuse with,
/// so it must be left untouched.
func.func @outerproduct_widening_2way__no_acc(%a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>) -> vector<[4]x[4]xf32> {
  %lhs = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
  %rhs = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
  %tile = arm_sme.outerproduct %lhs, %rhs : vector<[4]xf32>, vector<[4]xf32>
  return %tile : vector<[4]x[4]xf32>
}
855
856// -----
857
// CHECK-LABEL: @outerproduct_widening_4way__no_acc
// CHECK-NOT: arm_sme.{{[a-z]+}}_4way
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK-NOT: arm_sme.{{[a-z]+}}_4way
/// The chain is only three outer products long (the first has no accumulator),
/// so there is nothing to fuse into a 4-way op.
/// The operands are sign-extended i8, so a (wrong) fusion would produce an
/// integer 4-way op such as smopa_4way rather than fmopa_4way. The negative
/// patterns therefore match any `arm_sme.*_4way` op name.
func.func @outerproduct_widening_4way__no_acc(
    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>

  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>

  %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
  %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>

  %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
  %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>

  return %2 : vector<[4]x[4]xi32>
}
883
884// -----
885
/// Fusion requires the accumulator operand to be produced by another
/// 'arm_sme.outerproduct'. Here it is a function argument, so nothing fuses.

// CHECK-LABEL: @outerproduct_widening_2way__bad_acc
// CHECK-NOT: arm_sme.fmopa_2way
// CHECK: arm_sme.outerproduct
// CHECK-NOT: arm_sme.fmopa_2way
func.func @outerproduct_widening_2way__bad_acc(%a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>, %acc : vector<[4]x[4]xf32>) -> vector<[4]x[4]xf32> {
  %lhs = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
  %rhs = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
  %tile = arm_sme.outerproduct %lhs, %rhs acc(%acc) : vector<[4]xf32>, vector<[4]xf32>
  return %tile : vector<[4]x[4]xf32>
}
900
901// -----
902
// CHECK-LABEL: @outerproduct_widening_4way__missing_acc
// CHECK-NOT: arm_sme.{{[a-z]+}}_4way
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK-NOT: arm_sme.{{[a-z]+}}_4way
/// The fourth outer product takes no accumulator, so it is not chained to the
/// first three and no 4-way fusion may happen.
/// The negative patterns match any `arm_sme.*_4way` op. A fusion of these
/// sign-extended i8 operands would be an integer 4-way op (e.g. smopa_4way),
/// so checking only for fmopa_4way could never fail.
func.func @outerproduct_widening_4way__missing_acc(
    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>

  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>

  %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
  %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>

  %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
  %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>

  %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
  %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
  // Missing accumulator breaks use-def chain.
  %3 = arm_sme.outerproduct %a3_ext, %b3_ext : vector<[4]xi32>, vector<[4]xi32>
  "test.some_use"(%2) : (vector<[4]x[4]xi32>) -> ()

  return %3 : vector<[4]x[4]xi32>
}
936
937// -----
938
/// Outer products are only fused when their combining kinds agree; here an
/// `add` feeds a `sub`, so both must survive.

// CHECK-LABEL: @outerproduct_widening_2way__bad_combining_kind
// CHECK-NOT: arm_sme.fmopa_2way
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK-NOT: arm_sme.fmopa_2way
func.func @outerproduct_widening_2way__bad_combining_kind(
    %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
    %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>) -> vector<[4]x[4]xf32> {
  %lhs0 = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
  %rhs0 = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
  %lhs1 = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
  %rhs1 = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>

  %first = arm_sme.outerproduct %lhs0, %rhs0 kind<add> : vector<[4]xf32>, vector<[4]xf32>
  %second = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%first) : vector<[4]xf32>, vector<[4]xf32>
  return %second : vector<[4]x[4]xf32>
}
959
960// -----
961
// CHECK-LABEL: @outerproduct_widening_4way__inconsistent_combining_kind
// CHECK-NOT: arm_sme.{{[a-z]+}}_4way
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK-NOT: arm_sme.{{[a-z]+}}_4way
/// The first outer product uses `kind<sub>` while the other three use
/// `kind<add>`, so the chain must not be fused.
/// The negative patterns match any `arm_sme.*_4way` op, covering both the
/// accumulate (...mopa) and subtract (...mops) integer variants that a wrong
/// fusion of these sign-extended i8 operands could produce.
func.func @outerproduct_widening_4way__inconsistent_combining_kind(
    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>

  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>

  %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
  %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>

  %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
  %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>

  %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> : vector<[4]xi32>, vector<[4]xi32>
  %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<add> acc(%0) : vector<[4]xi32>, vector<[4]xi32>
  %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<add> acc(%1) : vector<[4]xi32>, vector<[4]xi32>
  %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<add> acc(%2) : vector<[4]xi32>, vector<[4]xi32>

  return %3 : vector<[4]x[4]xi32>
}
993
994// -----
995
/// When the first outer product has a use besides feeding the second one, it
/// cannot be erased after fusion. If it also has an accumulator, that
/// accumulator becomes the root for tile allocation. The widening outer product
/// shares the same accumulator, so both would be assigned the same tile ID,
/// leaving three outer products and producing incorrect results. This test
/// checks that fusion is suppressed in that situation.

// CHECK-LABEL: @outerproduct_widening_2way__cant_erase
// CHECK-NOT: arm_sme.fmopa_2way
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK-NOT: arm_sme.fmopa_2way
func.func @outerproduct_widening_2way__cant_erase(
    %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
    %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>) -> vector<[4]x[4]xf32> {
  %lhs0 = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
  %rhs0 = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
  %lhs1 = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
  %rhs1 = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>

  %ones = arith.constant dense<1.0> : vector<[4]x[4]xf32>
  %first = arm_sme.outerproduct %lhs0, %rhs0 acc(%ones) : vector<[4]xf32>, vector<[4]xf32>
  "test.some_use"(%first) : (vector<[4]x[4]xf32>) -> ()
  %second = arm_sme.outerproduct %lhs1, %rhs1 acc(%first) : vector<[4]xf32>, vector<[4]xf32>
  return %second : vector<[4]x[4]xf32>
}
1023
1024// -----
1025
// CHECK-LABEL: @outerproduct_widening_4way__multi_use_cant_erase
// CHECK-NOT: arm_sme.{{[a-z]+}}_4way
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK-NOT: arm_sme.{{[a-z]+}}_4way
/// The second outer product (%1) has an extra use ("test.some_use"), so it
/// could not be erased after fusion; the chain must be left unfused.
/// The negative patterns match any `arm_sme.*_4way` op, since a fusion of these
/// sign-extended i8 operands would be an integer 4-way op, not fmopa_4way.
func.func @outerproduct_widening_4way__multi_use_cant_erase(
    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>

  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>

  %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
  %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>

  %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
  %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>

  %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
  "test.some_use"(%1) : (vector<[4]x[4]xi32>) -> ()
  %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
  %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>

  return %3 : vector<[4]x[4]xi32>
}
1058
1059// -----
1060
// CHECK-LABEL: @outerproduct_widening_2way__unsupported_type_f32f32f64
// CHECK-NOT: arm_sme.fmopa_2way
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK-NOT: arm_sme.fmopa_2way
/// f32 -> f64 widening is not a supported 2-way combination, so both outer
/// products must remain.
func.func @outerproduct_widening_2way__unsupported_type_f32f32f64(
    %a0 : vector<[2]xf32>, %b0 : vector<[2]xf32>,
    %a1 : vector<[2]xf32>, %b1 : vector<[2]xf32>) -> vector<[2]x[2]xf64> {
  %lhs0 = arith.extf %a0 : vector<[2]xf32> to vector<[2]xf64>
  %rhs0 = arith.extf %b0 : vector<[2]xf32> to vector<[2]xf64>
  %lhs1 = arith.extf %a1 : vector<[2]xf32> to vector<[2]xf64>
  %rhs1 = arith.extf %b1 : vector<[2]xf32> to vector<[2]xf64>

  %first = arm_sme.outerproduct %lhs0, %rhs0 : vector<[2]xf64>, vector<[2]xf64>
  %second = arm_sme.outerproduct %lhs1, %rhs1 acc(%first) : vector<[2]xf64>, vector<[2]xf64>
  return %second : vector<[2]x[2]xf64>
}
1079
1080// -----
1081
// CHECK-LABEL: @outerproduct_widening_4way__unsupported_type_f16f16f64
// CHECK-NOT: arm_sme.{{[a-z]+}}_4way
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK-NOT: arm_sme.{{[a-z]+}}_4way
/// f16 -> f64 is an unsupported 4-way widening, so all four outer products must
/// remain. The negative patterns match any `arm_sme.*_4way` op, consistent with
/// the other 4-way negative tests, instead of one specific op name.
func.func @outerproduct_widening_4way__unsupported_type_f16f16f64(
    %a0 : vector<[2]xf16>, %b0 : vector<[2]xf16>,
    %a1 : vector<[2]xf16>, %b1 : vector<[2]xf16>,
    %a2 : vector<[2]xf16>, %b2 : vector<[2]xf16>,
    %a3 : vector<[2]xf16>, %b3 : vector<[2]xf16>) -> vector<[2]x[2]xf64> {
  %a0_ext = arith.extf %a0 : vector<[2]xf16> to vector<[2]xf64>
  %b0_ext = arith.extf %b0 : vector<[2]xf16> to vector<[2]xf64>

  %a1_ext = arith.extf %a1 : vector<[2]xf16> to vector<[2]xf64>
  %b1_ext = arith.extf %b1 : vector<[2]xf16> to vector<[2]xf64>

  %a2_ext = arith.extf %a2 : vector<[2]xf16> to vector<[2]xf64>
  %b2_ext = arith.extf %b2 : vector<[2]xf16> to vector<[2]xf64>

  %a3_ext = arith.extf %a3 : vector<[2]xf16> to vector<[2]xf64>
  %b3_ext = arith.extf %b3 : vector<[2]xf16> to vector<[2]xf64>

  %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[2]xf64>, vector<[2]xf64>
  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[2]xf64>, vector<[2]xf64>
  %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[2]xf64>, vector<[2]xf64>
  %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[2]xf64>, vector<[2]xf64>

  return %3 : vector<[2]x[2]xf64>
}
1113
1114// -----
1115
/// Outer products are fused only when both are masked or neither is; here only
/// the second one carries masks, so nothing fuses.

// CHECK-LABEL: @outerproduct_widening_2way__bad_masking
// CHECK-NOT: arm_sme.fmopa_2way
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK-NOT: arm_sme.fmopa_2way
func.func @outerproduct_widening_2way__bad_masking(
    %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
    %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>,
    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> {
  %lhs0 = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
  %rhs0 = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
  %lhs1 = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
  %rhs1 = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>

  %unmasked = arm_sme.outerproduct %lhs0, %rhs0 : vector<[4]xf32>, vector<[4]xf32>
  %masked = arm_sme.outerproduct %lhs1, %rhs1 acc(%unmasked) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32>
  return %masked : vector<[4]x[4]xf32>
}
1137
1138// -----
1139
// CHECK-LABEL: @outerproduct_widening_4way__inconsistent_masking
// CHECK-NOT: arm_sme.{{[a-z]+}}_4way
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK-NOT: arm_sme.{{[a-z]+}}_4way
/// Only the third outer product is masked; fusion requires all of them or none
/// of them to be masked, so the chain must be left unfused.
/// The negative patterns match any `arm_sme.*_4way` op, since a fusion of these
/// sign-extended i8 operands would be an integer 4-way op, not fmopa_4way.
func.func @outerproduct_widening_4way__inconsistent_masking(
    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
    %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>

  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>

  %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
  %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>

  %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
  %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>

  %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
  %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
  %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>

  return %3 : vector<[4]x[4]xi32>
}
1172
1173// -----
1174
/// The operands of each outer product must be defined by a supported extension
/// op. Here they are plain f32 arguments, so nothing fuses.

// CHECK-LABEL: @outerproduct_widening_2way__bad_defining_op
// CHECK-NOT: arm_sme.fmopa_2way
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK-NOT: arm_sme.fmopa_2way
func.func @outerproduct_widening_2way__bad_defining_op(
    %a0 : vector<[4]xf32>, %b0 : vector<[4]xf32>,
    %a1 : vector<[4]xf32>, %b1 : vector<[4]xf32>) -> vector<[4]x[4]xf32> {
  %first = arm_sme.outerproduct %a0, %b0 : vector<[4]xf32>, vector<[4]xf32>
  %second = arm_sme.outerproduct %a1, %b1 acc(%first) : vector<[4]xf32>, vector<[4]xf32>
  return %second : vector<[4]x[4]xf32>
}
1190
1191// -----
1192
// CHECK-LABEL: @outerproduct_widening_4way__bad_defining_op
// CHECK-NOT: arm_sme.{{[a-z]+}}_4way
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK-NOT: arm_sme.{{[a-z]+}}_4way
/// The third outer product takes plain i32 arguments rather than the results of
/// arith extension ops, so the chain must be left unfused.
/// The negative patterns match any `arm_sme.*_4way` op, since a fusion of the
/// sign-extended i8 operands would be an integer 4-way op, not fmopa_4way.
func.func @outerproduct_widening_4way__bad_defining_op(
    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
    %a2 : vector<[4]xi32>, %b2 : vector<[4]xi32>,
    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>

  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>

  %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
  %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>

  %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
  /// Inputs must come from an arith.ext.
  %2 = arm_sme.outerproduct %a2, %b2 acc(%1) : vector<[4]xi32>, vector<[4]xi32>
  %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>

  return %3 : vector<[4]x[4]xi32>
}
1222
1223/// Negative tests for related patterns.
1224
1225// -----
1226
1227/// Non-vector extracts should be ignored.
1228
// CHECK-LABEL: @extract_scalar_from_arith_ext
// CHECK-NEXT: arith.extsi
// CHECK-NEXT: vector.extract
func.func @extract_scalar_from_arith_ext(%narrow: vector<4x[8]xi8>) -> i32 {
  // The extract yields an i32 scalar, not a vector, so the extsi/extract
  // pair is expected to be printed back in its original order.
  %wide = arith.extsi %narrow : vector<4x[8]xi8> to vector<4x[8]xi32>
  %elem = vector.extract %wide[0, 0] : i32 from vector<4x[8]xi32>
  return %elem : i32
}
1237
1238// -----
1239
1240/// Extracted type should be a 1-D scalable vector type.
1241
// CHECK-LABEL: @extract_fixed_1d_vec_from_arith_ext
// CHECK-NEXT: arith.extsi
// CHECK-NEXT: vector.extract
func.func @extract_fixed_1d_vec_from_arith_ext(%narrow: vector<4x8xi8>) -> vector<8xi32> {
  // vector<8xi32> is fixed-length, so the extsi/extract pair is expected to be
  // left in place.
  %wide = arith.extsi %narrow : vector<4x8xi8> to vector<4x8xi32>
  %row = vector.extract %wide[0] : vector<8xi32> from vector<4x8xi32>
  return %row : vector<8xi32>
}
1250
1251// -----
1252
1253/// Extract must come from an arith extend.
1254
// CHECK-LABEL: @extract_from_non_arith_ext
// CHECK-NEXT: vector.extract
// CHECK-NEXT: return
func.func @extract_from_non_arith_ext(%wide: vector<4x[8]xi32>) -> vector<[8]xi32> {
  // The extract's source is a function argument, not an arith extension, so
  // the extract should be printed unchanged, directly followed by the return.
  %row = vector.extract %wide[0] : vector<[8]xi32> from vector<4x[8]xi32>
  return %row : vector<[8]xi32>
}
1262
1263// -----
1264
1265/// Scalable extract must come from an arith extend.
1266
// CHECK-LABEL: @scalable_extract_from_non_arith_ext
// CHECK-NEXT: vector.scalable.extract
// CHECK-NEXT: return
func.func @scalable_extract_from_non_arith_ext(%full: vector<[8]xf32>) -> vector<[4]xf32> {
  // The source is a function argument rather than an arith extension, so the
  // scalable extract should be printed unchanged, directly followed by the return.
  %half = vector.scalable.extract %full[0] : vector<[4]xf32> from vector<[8]xf32>
  return %half : vector<[4]xf32>
}
1274