xref: /llvm-project/mlir/test/Dialect/ArmSME/vector-legalization.mlir (revision a9eb8f0e3dbaf16b6bd83eecb960b6ea8ecaa8c3)
1// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize -split-input-file | FileCheck %s
2
/// An unmasked [8]x[8] f32 outer product without an accumulator is expected to
/// be split into four [4]x[4] outer products, one for each pairing of an LHS
/// half with an RHS half (halves taken with vector.scalable.extract).
3// CHECK-LABEL: @outerproduct_f32_scalable_8x8_no_acc(
4// CHECK-SAME:                                        %[[LHS:.*]]: vector<[8]xf32>,
5// CHECK-SAME:                                        %[[RHS:.*]]: vector<[8]xf32>)
6// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
7func.func @outerproduct_f32_scalable_8x8_no_acc(%lhs: vector<[8]xf32>, %rhs: vector<[8]xf32>) -> vector<[8]x[8]xf32>
8{
9  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[8]xf32>
10  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[8]xf32>
11  // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[4]xf32> from vector<[8]xf32>
12  // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][4] : vector<[4]xf32> from vector<[8]xf32>
13  // CHECK-DAG: %[[TOP_LEFT:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32>
14  // CHECK-DAG: %[[TOP_RIGHT:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]] : vector<[4]xf32>, vector<[4]xf32>
15  // CHECK-DAG: %[[BOTTOM_LEFT:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32>
16  // CHECK-DAG: %[[BOTTOM_RIGHT:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_1]] : vector<[4]xf32>, vector<[4]xf32>
17  // CHECK-NEXT: return %[[TOP_LEFT]], %[[TOP_RIGHT]], %[[BOTTOM_LEFT]], %[[BOTTOM_RIGHT]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
18  %0 = vector.outerproduct %lhs, %rhs : vector<[8]xf32>, vector<[8]xf32>
19  return %0 : vector<[8]x[8]xf32>
20}
21
22// -----
23
/// A [4]x[16] f32 outer product with an accumulator is expected to be split
/// along the RHS into four [4]x[4] outer products. The [4]x[16] accumulator
/// argument is expected to become four [4]x[4] arguments (ACC_0..ACC_3), each
/// feeding exactly one of the pieces (combining kind <add>).
24// CHECK-LABEL: @outerproduct_f32_scalable_4x16_acc(
25// CHECK-SAME:                                      %[[LHS:.*]]: vector<[4]xf32>,
26// CHECK-SAME:                                      %[[RHS:.*]]: vector<[16]xf32>,
27// CHECK-SAME:                                      %[[ACC_0:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
28// CHECK-SAME:                                      %[[ACC_1:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
29// CHECK-SAME:                                      %[[ACC_2:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
30// CHECK-SAME:                                      %[[ACC_3:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>)
31// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
32func.func @outerproduct_f32_scalable_4x16_acc(%lhs: vector<[4]xf32>, %rhs: vector<[16]xf32>, %acc: vector<[4]x[16]xf32>) -> vector<[4]x[16]xf32>
33{
34  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[4]xf32>
35  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[16]xf32>
36  // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][4] : vector<[4]xf32> from vector<[16]xf32>
37  // CHECK-DAG: %[[RHS_2:.*]] = vector.scalable.extract %[[RHS]][8] : vector<[4]xf32> from vector<[16]xf32>
38  // CHECK-DAG: %[[RHS_3:.*]] = vector.scalable.extract %[[RHS]][12] : vector<[4]xf32> from vector<[16]xf32>
39  // CHECK-DAG: %[[RES_0:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]], %[[ACC_0]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
40  // CHECK-DAG: %[[RES_1:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]], %[[ACC_1]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
41  // CHECK-DAG: %[[RES_2:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_2]], %[[ACC_2]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
42  // CHECK-DAG: %[[RES_3:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_3]], %[[ACC_3]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
43  // CHECK-NEXT: return %[[RES_0]], %[[RES_1]], %[[RES_2]], %[[RES_3]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
44  %0 = vector.outerproduct %lhs, %rhs, %acc : vector<[4]xf32>, vector<[16]xf32>
45  return %0 : vector<[4]x[16]xf32>
46}
47
48// -----
49
/// A masked [16]x[4] f32 outer product is expected to be split along the LHS
/// into four [4]x[4] masked outer products. Each piece gets its own
/// vector.create_mask whose row bound is %lhs_dim minus that piece's row offset
/// (0, 4*vscale, 8*vscale, 12*vscale); the column bound stays %rhs_dim.
50// CHECK-LABEL: @outerproduct_f32_masked_scalable_16x4(
51// CHECK-SAME:                                         %[[LHS:.*]]: vector<[16]xf32>,
52// CHECK-SAME:                                         %[[RHS:.*]]: vector<[4]xf32>,
53// CHECK-SAME:                                         %[[LHS_DIM:.*]]: index,
54// CHECK-SAME:                                         %[[RHS_DIM:.*]]: index)
55// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
56func.func @outerproduct_f32_masked_scalable_16x4(%lhs: vector<[16]xf32>, %rhs: vector<[4]xf32>, %lhs_dim: index, %rhs_dim: index) -> vector<[16]x[4]xf32>
57{
58  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
59  // CHECK-DAG: %[[MINUS_4:.*]] = arith.constant -4 : index
60  // CHECK-DAG: %[[MINUS_8:.*]] = arith.constant -8 : index
61  // CHECK-DAG: %[[MINUS_12:.*]] = arith.constant -12 : index
62  // CHECK-DAG: %[[MINUS_4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_4]] : index
63  // CHECK-DAG: %[[MINUS_8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_8]] : index
64  // CHECK-DAG: %[[MINUS_12_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_12]] : index
65  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[16]xf32>
66  // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[4]xf32> from vector<[16]xf32>
67  // CHECK-DAG: %[[LHS_2:.*]] = vector.scalable.extract %[[LHS]][8] : vector<[4]xf32> from vector<[16]xf32>
68  // CHECK-DAG: %[[LHS_3:.*]] = vector.scalable.extract %[[LHS]][12] : vector<[4]xf32> from vector<[16]xf32>
69  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[4]xf32>
70  // CHECK-DAG: %[[MASK_0:.*]] = vector.create_mask %[[LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
71  // CHECK-DAG: %[[TILE_1_LHS_DIM:.*]] = arith.addi %[[LHS_DIM]], %[[MINUS_4_VSCALE]] : index
72  // CHECK-DAG: %[[MASK_1:.*]] = vector.create_mask %[[TILE_1_LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
73  // CHECK-DAG: %[[TILE_2_LHS_DIM:.*]] = arith.addi %[[LHS_DIM]], %[[MINUS_8_VSCALE]] : index
74  // CHECK-DAG: %[[MASK_2:.*]] = vector.create_mask %[[TILE_2_LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
75  // CHECK-DAG: %[[TILE_3_LHS_DIM:.*]] = arith.addi %[[LHS_DIM]], %[[MINUS_12_VSCALE]] : index
76  // CHECK-DAG: %[[MASK_3:.*]] = vector.create_mask %[[TILE_3_LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
77  // CHECK-DAG: %[[RES_0:.*]] = vector.mask %[[MASK_0]] { vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
78  // CHECK-DAG: %[[RES_1:.*]] = vector.mask %[[MASK_1]] { vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
79  // CHECK-DAG: %[[RES_2:.*]] = vector.mask %[[MASK_2]] { vector.outerproduct %[[LHS_2]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
80  // CHECK-DAG: %[[RES_3:.*]] = vector.mask %[[MASK_3]] { vector.outerproduct %[[LHS_3]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
81  // CHECK-NEXT: return %[[RES_0]], %[[RES_1]], %[[RES_2]], %[[RES_3]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
82  %mask = vector.create_mask %lhs_dim, %rhs_dim : vector<[16]x[4]xi1>
83  %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs : vector<[16]xf32>, vector<[4]xf32> } : vector<[16]x[4]xi1> -> vector<[16]x[4]xf32>
84  return %0 : vector<[16]x[4]xf32>
85}
86
87// -----
88
89/// This demonstrates a rectangular tiling that uses all f64 accumulators.
/// The [8]x[4] f64 result is expected to be split into eight [2]x[2] outer
/// products: four [2] LHS slices paired with each of two [2] RHS slices.
90
91// CHECK-LABEL: @outerproduct_f64_scalable_8x4_no_acc(
92// CHECK-SAME:                                        %[[LHS:.*]]: vector<[8]xf64>,
93// CHECK-SAME:                                        %[[RHS:.*]]: vector<[4]xf64>)
94// CHECK-SAME: -> (vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>)
95func.func @outerproduct_f64_scalable_8x4_no_acc(%lhs: vector<[8]xf64>, %rhs: vector<[4]xf64>) -> vector<[8]x[4]xf64>
96{
97  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[2]xf64> from vector<[8]xf64>
98  // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][2] : vector<[2]xf64> from vector<[8]xf64>
99  // CHECK-DAG: %[[LHS_2:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[2]xf64> from vector<[8]xf64>
100  // CHECK-DAG: %[[LHS_3:.*]] = vector.scalable.extract %[[LHS]][6] : vector<[2]xf64> from vector<[8]xf64>
101  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[2]xf64> from vector<[4]xf64>
102  // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][2] : vector<[2]xf64> from vector<[4]xf64>
103  // CHECK-DAG: %[[RES_0:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
104  // CHECK-DAG: %[[RES_1:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
105  // CHECK-DAG: %[[RES_2:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
106  // CHECK-DAG: %[[RES_3:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
107  // CHECK-DAG: %[[RES_4:.*]] = vector.outerproduct %[[LHS_2]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
108  // CHECK-DAG: %[[RES_5:.*]] = vector.outerproduct %[[LHS_2]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
109  // CHECK-DAG: %[[RES_6:.*]] = vector.outerproduct %[[LHS_3]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
110  // CHECK-DAG: %[[RES_7:.*]] = vector.outerproduct %[[LHS_3]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
111  // CHECK-NEXT: return %[[RES_0]], %[[RES_1]], %[[RES_2]], %[[RES_3]], %[[RES_4]], %[[RES_5]], %[[RES_6]], %[[RES_7]] : vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>
112  %0 = vector.outerproduct %lhs, %rhs : vector<[8]xf64>, vector<[4]xf64>
113  return %0 : vector<[8]x[4]xf64>
114}
115
116// -----
117
/// An unmasked [8]x[8] i32 transfer_read is expected to be split into four
/// [4]x[4] reads at row/column offsets of 0 or 4*vscale.
/// (Named i32 to match the element type; it was previously misnamed f32.)
118// CHECK-LABEL: @transfer_read_i32_scalable_8x8(
119// CHECK-SAME:                                  %[[SRC:.*]]: memref<?x?xi32>)
120// CHECK-SAME: -> (vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>)
121func.func @transfer_read_i32_scalable_8x8(%src: memref<?x?xi32>) -> vector<[8]x[8]xi32>
122{
123  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
124  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
125  // CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32
126  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
127  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
128  // CHECK-DAG: %[[TOP_LEFT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
129  // CHECK-DAG: %[[TOP_RIGHT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
130  // CHECK-DAG: %[[BOTTOM_LEFT:.*]] = vector.transfer_read %[[SRC]][%[[C4_VSCALE]], %[[C0]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
131  // CHECK-DAG: %[[BOTTOM_RIGHT:.*]] = vector.transfer_read %[[SRC]][%[[C4_VSCALE]], %[[C4_VSCALE]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
132  // CHECK-NEXT: return %[[TOP_LEFT]], %[[TOP_RIGHT]], %[[BOTTOM_LEFT]], %[[BOTTOM_RIGHT]] : vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>
133  %c0 = arith.constant 0 : index
134  %pad = arith.constant 0 : i32
135  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi32>, vector<[8]x[8]xi32>
136  return %0 : vector<[8]x[8]xi32>
137}
138
139// -----
140
/// A masked [8]x[16] i16 transfer_read is expected to be split into a left and
/// a right [8]x[8] read. The right read starts at column 8*vscale, and its mask
/// column bound is %dim1 minus 8*vscale; both masks keep %dim0 as the row bound.
141// CHECK-LABEL: @transfer_read_i16_scalable_8x16_masked(
142// CHECK-SAME:                                          %[[SRC:.*]]: memref<?x?xi16>,
143// CHECK-SAME:                                          %[[DIM0:.*]]: index,
144// CHECK-SAME:                                          %[[DIM1:.*]]: index)
145// CHECK-SAME: -> (vector<[8]x[8]xi16>, vector<[8]x[8]xi16>)
146func.func @transfer_read_i16_scalable_8x16_masked(%src: memref<?x?xi16>, %dim0: index, %dim1: index) -> vector<[8]x[16]xi16>
147{
148  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
149  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
150  // CHECK-DAG: %[[MINUS_8:.*]] = arith.constant -8 : index
151  // CHECK-DAG: %[[C0_I16:.*]] = arith.constant 0 : i16
152  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
153  // CHECK-DAG: %[[MINUS_8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_8]] : index
154  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
155  // CHECK-DAG: %[[RIGHT_DIM_1:.*]] = arith.addi %[[DIM1]], %[[MINUS_8_VSCALE]] : index
156  // CHECK-DAG: %[[LEFT_MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<[8]x[8]xi1>
157  // CHECK-DAG: %[[RIGHT_MASK:.*]] = vector.create_mask %[[DIM0]], %[[RIGHT_DIM_1]] : vector<[8]x[8]xi1>
158  // CHECK-DAG: %[[LEFT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[C0_I16]], %[[LEFT_MASK]] {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
159  // CHECK-DAG: %[[RIGHT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[C0_I16]], %[[RIGHT_MASK]] {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
160  // CHECK-NEXT: return %[[LEFT]], %[[RIGHT]] : vector<[8]x[8]xi16>, vector<[8]x[8]xi16>
161  %c0 = arith.constant 0 : index
162  %pad = arith.constant 0 : i16
163  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[16]xi1>
164  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[16]xi16>
165  return %0 : vector<[8]x[16]xi16>
166}
167
168// -----
169
/// An unmasked [16]x[8] f16 write of a function-argument vector is expected to
/// arrive as two [8]x[8] tiles (TOP, BOTTOM) and be lowered to an scf.for over
/// 8*vscale rows. Each iteration writes one row slice of each tile; the bottom
/// tile's destination row is offset by 8*vscale.
170// CHECK-LABEL: @transfer_write_f16_scalable_16x8(
171// CHECK-SAME:                                    %[[DEST:.*]]: memref<?x?xf16>,
172// CHECK-SAME:                                    %[[TOP:.*]]: vector<[8]x[8]xf16>,
173// CHECK-SAME:                                    %[[BOTTOM:.*]]: vector<[8]x[8]xf16>)
174func.func @transfer_write_f16_scalable_16x8(%dest: memref<?x?xf16>, %vec: vector<[16]x[8]xf16>)
175{
176  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
177  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
178  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
179  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
180  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
181  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C8_VSCALE]] step %[[C1]] {
182  // CHECK-NEXT:   %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
183  // CHECK-NEXT:   vector.transfer_write %[[TOP_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
184  // CHECK-NEXT:   %[[BOTTOM_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
185  // CHECK-NEXT:   %[[BOTTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
186  // CHECK-NEXT:   vector.transfer_write %[[BOTTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
187  // CHECK-NEXT: }
188  // CHECK-NEXT: return
189  %c0 = arith.constant 0 : index
190  vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[8]xf16>, memref<?x?xf16>
191  return
192}
193
194// -----
195
196/// This is already a legal type. It should be ignored.
/// The [16]x[16] i8 write is expected to survive unchanged. The mask uses both
/// %dim0 and %dim1 (previously %dim0 was passed twice, leaving %dim1 unused).
197
198// CHECK-LABEL: @transfer_write_i8_scalable_16x16_masked
199func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec: vector<[16]x[16]xi8>, %dim0: index, %dim1: index)
200{
201  // CHECK: vector.transfer_write {{.*}} : vector<[16]x[16]xi8>, memref<?x?xi8>
202  %c0 = arith.constant 0 : index
203  %mask = vector.create_mask %dim0, %dim1 : vector<[16]x[16]xi1>
204  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
205  return
206}
207
208// -----
209
/// A masked [8]x[8] f32 write of a function-argument vector (arriving as four
/// [4]x[4] tiles) is expected to be lowered to an scf.for over 4*vscale rows.
/// Each iteration extracts the matching row of the original [8]x[8] mask, splits
/// it into two [4] halves with vector.scalable.extract, and writes one row slice
/// of each tile. The lower tiles' rows are offset by 4*vscale, and the right
/// tiles' columns start at 4*vscale.
210// CHECK-LABEL: @transfer_write_f32_scalable_8x8_masked(
211// CHECK-SAME:                                    %[[DEST:[a-z0-9]+]]: memref<?x?xf32>,
212// CHECK-SAME:                                    %[[DIM_0:[a-z0-9]+]]: index,
213// CHECK-SAME:                                    %[[DIM_1:[a-z0-9]+]]: index,
214// CHECK-SAME:                                    %[[TILE_0:[a-z0-9]+]]: vector<[4]x[4]xf32>,
215// CHECK-SAME:                                    %[[TILE_1:[a-z0-9]+]]: vector<[4]x[4]xf32>,
216// CHECK-SAME:                                    %[[TILE_2:[a-z0-9]+]]: vector<[4]x[4]xf32>,
217// CHECK-SAME:                                    %[[TILE_3:[a-z0-9]+]]: vector<[4]x[4]xf32>)
218func.func @transfer_write_f32_scalable_8x8_masked(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
219{
220  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
221  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
222  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
223  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
224  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
225  // CHECK-DAG: %[[MASK:.*]] =  vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<[8]x[8]xi1>
226  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
227  // CHECK-NEXT:   %[[UPPER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
228  // CHECK-NEXT:   %[[TILE_0_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
229  // CHECK-NEXT:   %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
230  // CHECK-NEXT:   vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]], %[[TILE_0_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
231  // CHECK-NEXT:   %[[TILE_1_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
232  // CHECK-NEXT:   %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
233  // CHECK-NEXT:   vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[I]], %[[C4_VSCALE]]], %[[TILE_1_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
234  // CHECK-NEXT:   %[[LOWER_SLICE_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
235  // CHECK-NEXT:   %[[LOWER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[LOWER_SLICE_I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
236  // CHECK-NEXT:   %[[TILE_2_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
237  // CHECK-NEXT:   %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
238  // CHECK-NEXT:   vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C0]]], %[[TILE_2_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
239  // CHECK-NEXT:   %[[TILE_3_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
240  // CHECK-NEXT:   %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
241  // CHECK-NEXT:   vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C4_VSCALE]]], %[[TILE_3_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
242  // CHECK-NEXT: }
  // CHECK-NEXT: return
243  %c0 = arith.constant 0 : index
244  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
245  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
246  return
247}
248
249// -----
250
251// Tensor semantics are not supported for the store loop lowering.
/// The destination here is a tensor, so no scf.for store loop is expected to
/// be emitted for this [8]x[8] write.
252
253// CHECK-LABEL: @negative_transfer_write_f32_scalable_8x8_tensor
254// CHECK-NOT: scf.for
255func.func @negative_transfer_write_f32_scalable_8x8_tensor(%dest: tensor<?x?xf32>, %vec: vector<[8]x[8]xf32>)
256{
257  %c0 = arith.constant 0 : index
258  vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
259  return
260}
261
262// -----
263
264#transpose = affine_map<(d0, d1) -> (d1, d0)>
265
266// Transposes are not supported for the store loop lowering.
/// The destination is a memref (not a tensor), so the transposing
/// permutation_map is the only property that rules out the store loop here.
267
268// CHECK-LABEL: @negative_transfer_write_f32_scalable_8x8_transpose
269// CHECK-NOT: scf.for
270func.func @negative_transfer_write_f32_scalable_8x8_transpose(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
271{
272  %c0 = arith.constant 0 : index
273  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
274  vector.transfer_write %vec, %dest[%c0, %c0], %mask {permutation_map = #transpose, in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
275  return
276}
277
278// -----
279
280// Masked writes where any dimension of the mask is > 16 are not supported for the store loop lowering.
/// Here both mask dimensions are [32], so no scf.for store loop is expected.
281
282// CHECK-LABEL: @negative_transfer_write_f32_scalable_32x32
283// CHECK-NOT: scf.for
284func.func @negative_transfer_write_f32_scalable_32x32(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[32]x[32]xf32>)
285{
286  %c0 = arith.constant 0 : index
287  %mask = vector.create_mask %dim0, %dim1 : vector<[32]x[32]xi1>
288  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[32]x[32]xf32>, memref<?x?xf32>
289  return
290}
291
292// -----
293
294#transpose = affine_map<(d0, d1) -> (d1, d0)>
295
/// A transposing [16]x[4] read followed by a plain write. The read is expected
/// to become four [4]x[4] reads that keep the permutation map, at source column
/// offsets 0, 4, 8 and 12 times vscale. The write is expected to become an
/// scf.for store loop that writes one row slice of each tile per iteration.
296// CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(
297// CHECK-SAME:                                        %[[SRC:.*]]: memref<?x?xf32>,
298// CHECK-SAME:                                        %[[DEST:.*]]: memref<?x?xf32>)
299func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
300{
301  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
302  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
303  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
304  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
305  // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
306  // CHECK-DAG: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
307  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
308  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
309  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
310  // CHECK-DAG: %[[C12_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C12]] : index
311  // CHECK-DAG: %[[TILE_0:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
312  // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
313  // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
314  // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
315  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
316  // CHECK-NEXT:   %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
317  // CHECK-NEXT:   vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
318  // CHECK-NEXT:   %[[TILE_1_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
319  // CHECK-NEXT:   %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
320  // CHECK-NEXT:   vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[TILE_1_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
321  // CHECK-NEXT:   %[[TILE_2_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
322  // CHECK-NEXT:   %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
323  // CHECK-NEXT:   vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[TILE_2_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
324  // CHECK-NEXT:   %[[TILE_3_I:.*]] = arith.addi %[[C12_VSCALE]], %[[I]] : index
325  // CHECK-NEXT:   %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
326  // CHECK-NEXT:   vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[TILE_3_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
327  // CHECK-NEXT: }
328  // CHECK-NEXT: return
329  %c0 = arith.constant 0 : index
330  %pad = arith.constant 0.0 : f32
331  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = #transpose, in_bounds = [true, true]} : memref<?x?xf32>, vector<[16]x[4]xf32>
332  vector.transfer_write %0, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[4]xf32>, memref<?x?xf32>
333  return
334}
335
336// -----
337
338#transpose = affine_map<(d0, d1) -> (d1, d0)>
339
/// A plain [4]x[16] read followed by a transposing write. The read is expected
/// to become four [4]x[4] reads at column offsets 0, 4, 8 and 12 times vscale.
/// The write is expected to become four [4]x[4] writes that keep the
/// permutation map, at destination row offsets 0, 4, 8 and 12 times vscale,
/// rather than an scf.for store loop.
340// CHECK-LABEL: @transpose_f32_scalable_4x16_via_write(
341// CHECK-SAME:                                         %[[SRC:.*]]: memref<?x?xf32>,
342// CHECK-SAME:                                         %[[DEST:.*]]: memref<?x?xf32>)
343func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
344{
345  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
346  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
347  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
348  // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
349  // CHECK-DAG: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
350  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
351  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
352  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
353  // CHECK-DAG: %[[C12_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C12]] : index
354  // CHECK-DAG: %[[TILE_0:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
355  // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
356  // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
357  // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
358  // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
359  // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
360  // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
361  // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
362  // CHECK-NEXT: return
363  %c0 = arith.constant 0 : index
364  %pad = arith.constant 0.0 : f32
365  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[16]xf32>
366  vector.transfer_write %0, %dest[%c0, %c0] {permutation_map = #transpose, in_bounds = [true, true]} : vector<[4]x[16]xf32>, memref<?x?xf32>
367  return
368}
369
370// -----
371
/// Extracting constant position 2 from a 4x[4]x[4] vector.create_mask is
/// expected to fold to a [4]x[4] create_mask. Its leading bound is %dim1 when
/// %dim0 > 2 (signed compare) and 0 otherwise; the trailing bound is %dim2.
372// CHECK-LABEL: @extract_from_vector_create_mask_non_constant_dim(
373// CHECK-SAME:                                                    %[[DIM0:[a-z0-9]+]]: index,
374// CHECK-SAME:                                                    %[[DIM1:[a-z0-9]+]]: index,
375// CHECK-SAME:                                                    %[[DIM2:[a-z0-9]+]]: index)
376func.func @extract_from_vector_create_mask_non_constant_dim(%dim0: index, %dim1: index, %dim2: index) -> vector<[4]x[4]xi1> {
377  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
378  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
379  // CHECK-NEXT: %[[DIM0_CMP:.*]] = arith.cmpi sgt, %[[DIM0]], %[[C2]] : index
380  // CHECK-NEXT: %[[NEW_DIM0:.*]] = arith.select %[[DIM0_CMP]], %[[DIM1]], %[[C0]] : index
381  // CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1>
382  // CHECK-NEXT: return %[[EXTRACT]]
383  %mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1>
384  %extract = vector.extract %mask[2] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
385  return %extract : vector<[4]x[4]xi1>
386}
387
388// -----
389
/// Extracting a dynamic position %index from a 4x[4]x[4] vector.create_mask is
/// expected to fold to a [4]x[4] create_mask. Its leading bound is %dim1 when
/// %index < %dim0 (signed compare) and 0 otherwise; the trailing bound is %dim2.
390// CHECK-LABEL: @non_constant_extract_from_vector_create_mask_non_constant(
391// CHECK-SAME:                                                             %[[INDEX:[a-z0-9]+]]: index,
392// CHECK-SAME:                                                             %[[DIM0:[a-z0-9]+]]: index,
393// CHECK-SAME:                                                             %[[DIM1:[a-z0-9]+]]: index,
394// CHECK-SAME:                                                             %[[DIM2:[a-z0-9]+]]: index)
395func.func @non_constant_extract_from_vector_create_mask_non_constant(%index: index, %dim0: index, %dim1: index, %dim2: index) -> vector<[4]x[4]xi1> {
396  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
397  // CHECK-NEXT: %[[DIM0_CMP:.*]] = arith.cmpi slt, %[[INDEX]], %[[DIM0]] : index
398  // CHECK-NEXT: %[[NEW_DIM0:.*]] = arith.select %[[DIM0_CMP]], %[[DIM1]], %[[C0]] : index
399  // CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1>
400  // CHECK-NEXT: return %[[EXTRACT]]
401  %mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1>
402  %extract = vector.extract %mask[%index] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
403  return %extract : vector<[4]x[4]xi1>
404}
405
406// -----
407
/// A vector.transpose of an illegal [8]x4 transfer_read is expected to be
/// lifted into memory. The source is viewed through memref.subview, memref.cast
/// and memref.transpose, then read directly as the legal 4x[8] type.
408// CHECK-LABEL: @lift_illegal_transpose_to_memory(
409// CHECK-SAME:                                    %[[INDEXA:[a-z0-9]+]]: index,
410// CHECK-SAME:                                    %[[INDEXB:[a-z0-9]+]]: index,
411// CHECK-SAME:                                    %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
412func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
413  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
414  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
415  // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
416  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
417  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
418  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
419  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
420  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
421  // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%[[C0]], %[[C0]]], %[[C0_F32]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
422  // CHECK-NEXT: return %[[LEGAL_READ]]
423  %pad = arith.constant 0.0 : f32
424  %illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xf32>, vector<[8]x4xf32>
425  %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
426  return %legalType : vector<4x[8]xf32>
427}
428
429// -----
430
// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_mask(
// CHECK-SAME:    %[[DIM0:[a-z0-9]+]]: index,
// CHECK-SAME:    %[[DIM1:[a-z0-9]+]]: index,
// CHECK-SAME:    %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>
func.func @lift_illegal_transpose_to_memory_with_mask(%dim0: index, %dim1: index, %memref: memref<?x?xf32>, %a: index, %b: index) -> vector<4x[8]xf32> {
  // The mask travels with the lifted read: its operands are swapped so it
  // matches the transposed vector<4x[8]xf32> read.
  // CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[MEMREF]]
  // CHECK-DAG: %[[SUBVIEW_CAST:.*]] = memref.cast %[[SUBVIEW]]
  // CHECK-DAG: %[[TRANSPOSED_VIEW:.*]] = memref.transpose %[[SUBVIEW_CAST]]
  // CHECK-DAG: %[[TRANSPOSED_MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
  // CHECK:     %[[READ:.*]] = vector.transfer_read %[[TRANSPOSED_VIEW]]
  // CHECK-SAME:  %[[TRANSPOSED_MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
  // CHECK-NEXT: return %[[READ]]
  %padding = arith.constant 0.0 : f32
  %srcMask = vector.create_mask %dim0, %dim1 : vector<[8]x4xi1>
  %srcRead = vector.transfer_read %memref[%a, %b], %padding, %srcMask : memref<?x?xf32>, vector<[8]x4xf32>
  %result = vector.transpose %srcRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
  return %result : vector<4x[8]xf32>
}
449
450// -----
451
// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_arith_extop(
// CHECK-SAME:    %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>
func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: index, %memref: memref<?x?xi8>) -> vector<4x[8]xi32> {
  // The arith.extsi between the read and the transpose is preserved, but it
  // now widens the legal, already-transposed vector<4x[8]xi8> read.
  // CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[MEMREF]]
  // CHECK-DAG: %[[SUBVIEW_CAST:.*]] = memref.cast %[[SUBVIEW]]
  // CHECK-DAG: %[[TRANSPOSED_VIEW:.*]] = memref.transpose %[[SUBVIEW_CAST]]
  // CHECK:     %[[READ:.*]] = vector.transfer_read %[[TRANSPOSED_VIEW]]
  // CHECK-NEXT: %[[WIDENED:.*]] = arith.extsi %[[READ]] : vector<4x[8]xi8> to vector<4x[8]xi32>
  // CHECK-NEXT: return %[[WIDENED]]
  %zero = arith.constant 0 : i8
  %narrowRead = vector.transfer_read %memref[%a, %b], %zero : memref<?x?xi8>, vector<[8]x4xi8>
  %wideRead = arith.extsi %narrowRead : vector<[8]x4xi8> to vector<[8]x4xi32>
  %result = vector.transpose %wideRead, [1, 0] : vector<[8]x4xi32> to vector<4x[8]xi32>
  return %result : vector<4x[8]xi32>
}
467
468// -----
469
// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_in_bounds_attr
func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
  // The in_bounds flags of the original read must be permuted together with
  // its dimensions: [false, true] becomes [true, false].
  // CHECK: vector.transfer_read
  // CHECK-SAME: in_bounds = [true, false]
  // CHECK-NOT: in_bounds = [false, true]
  %padding = arith.constant 0.0 : f32
  %srcRead = vector.transfer_read %memref[%a, %b], %padding {in_bounds = [false, true]}: memref<?x?xf32>, vector<[8]x4xf32>
  %result = vector.transpose %srcRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
  return %result : vector<4x[8]xf32>
}
480
481// -----
482
// CHECK-LABEL: @illegal_transpose_no_defining_source_op
// The transposed value is a block argument, so there is no defining op to
// lift. The pass must leave the transpose in place (and must not crash).
func.func @illegal_transpose_no_defining_source_op(%input: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
  // CHECK: vector.transpose
  %transposed = vector.transpose %input, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
  return %transposed : vector<1x[4]xf32>
}
491
492// -----
493
// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
// CHECK-SAME:    %[[VEC:.*]]: vector<[4]x1xf32>)
func.func @illegal_shape_cast_to_transpose_2d(%input: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
  // This [4]x1 -> 1x[4] shape_cast is expected to become an explicit
  // 2-D vector.transpose.
  // CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
  %reshaped = vector.shape_cast %input : vector<[4]x1xf32> to vector<1x[4]xf32>
  return %reshaped : vector<1x[4]xf32>
}
501
502// -----
503
// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
// CHECK-SAME:    %[[VEC:.*]]: vector<[4]x1xf32>)
func.func @illegal_shape_cast_to_transpose_1d(%input: vector<[4]x1xf32>) -> vector<[4]xf32> {
  // Flattening [4]x1 to [4] is expected to go through a transpose to 1x[4],
  // followed by a 1x[4] -> [4] shape_cast.
  // CHECK: %[[TRANSPOSED:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
  // CHECK: vector.shape_cast %[[TRANSPOSED]] : vector<1x[4]xf32> to vector<[4]xf32>
  %flat = vector.shape_cast %input : vector<[4]x1xf32> to vector<[4]xf32>
  return %flat : vector<[4]xf32>
}
512
513// -----
514
// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
  // Reading vector<[4]x1xf32> and shape_cast-ing it to vector<1x[4]xf32> is
  // expected to collapse into one legal vector<1x[4]xf32> read, with no
  // shape_cast after it.
  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
  // CHECK-NOT: vector.shape_cast
  %padding = arith.constant 0.0 : f32
  %srcRead = vector.transfer_read %memref[%a, %b], %padding {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
  %reshaped = vector.shape_cast %srcRead : vector<[4]x1xf32> to vector<1x[4]xf32>
  return %reshaped : vector<1x[4]xf32>
}
524
525// -----
526
// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
  // The [4]x1 read is expected to be lifted to a legal vector<1x[4]xf32> read,
  // so the original [4]x1 -> [4] shape_cast must not appear after it.
  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
  // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
  %padding = arith.constant 0.0 : f32
  %srcRead = vector.transfer_read %memref[%a, %b], %padding {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
  %flat = vector.shape_cast %srcRead : vector<[4]x1xf32> to vector<[4]xf32>
  return %flat : vector<[4]xf32>
}
536
537// -----
538
// CHECK-LABEL: @multi_tile_splat
func.func @multi_tile_splat() -> vector<[8]x[8]xi32> {
  // A [8]x[8] splat constant is legalized into four [4]x[4] tiles that all
  // reuse a single [4]x[4] splat constant.
  // CHECK: %[[TILE_SPLAT:.*]] = arith.constant dense<42> : vector<[4]x[4]xi32>
  // CHECK-NEXT: return %[[TILE_SPLAT]], %[[TILE_SPLAT]], %[[TILE_SPLAT]], %[[TILE_SPLAT]] : vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>
  %splat = arith.constant dense<42> : vector<[8]x[8]xi32>
  return %splat : vector<[8]x[8]xi32>
}
547
548// -----
549
// CHECK: #[[$TRANSPOSE_MAP_0:.*]] = affine_map<(d0, d1) -> (d1, d0)>

// CHECK-LABEL: @transpose_store_scalable_via_za(
// CHECK-SAME:    %[[VEC:.*]]: vector<2x[4]xf32>
// CHECK-SAME:    %[[DEST:.*]]: memref<?x?xf32>,
// CHECK-SAME:    %[[I:.*]]: index,
// CHECK-SAME:    %[[J:.*]]: index)
func.func @transpose_store_scalable_via_za(%src: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %row: index, %col: index) {
  // Both [4]-element rows are inserted into an SME tile. The tile is then
  // stored with a transposing permutation_map under a (4 x vscale) x 2 mask.
  // CHECK-DAG: %[[NUM_SLICES:.*]] = arith.constant 2 : index
  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
  // CHECK-NEXT: %[[UNDEF_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xf32>
  // CHECK-NEXT: %[[ROW_0:.*]] = vector.extract %[[VEC]][0] : vector<[4]xf32> from vector<2x[4]xf32>
  // CHECK-NEXT: %[[PARTIAL_TILE:.*]] = vector.insert %[[ROW_0]], %[[UNDEF_TILE]] [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
  // CHECK-NEXT: %[[ROW_1:.*]] = vector.extract %[[VEC]][1] : vector<[4]xf32> from vector<2x[4]xf32>
  // CHECK-NEXT: %[[TILE:.*]] = vector.insert %[[ROW_1]], %[[PARTIAL_TILE]] [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
  // CHECK-NEXT: %[[VSCALE:.*]] = vector.vscale
  // CHECK-NEXT: %[[TILE_ROWS:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
  // CHECK-NEXT: %[[WRITE_MASK:.*]] = vector.create_mask %[[TILE_ROWS]], %[[NUM_SLICES]] : vector<[4]x[4]xi1>
  // CHECK-NEXT: vector.transfer_write %[[TILE]], %[[DEST]][%[[I]], %[[J]]], %[[WRITE_MASK]] {in_bounds = [true, true], permutation_map = #[[$TRANSPOSE_MAP_0]]} : vector<[4]x[4]xf32>, memref<?x?xf32>
  %transposed = vector.transpose %src, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
  vector.transfer_write %transposed, %dest[%row, %col] {in_bounds = [true, true]} : vector<[4]x2xf32>,  memref<?x?xf32>
  return
}
573
574// -----
575
// CHECK-LABEL: @transpose_store_scalable_via_za_masked(
// CHECK-SAME:    %[[A:[a-z0-9]+]]: index,
// CHECK-SAME:    %[[B:[a-z0-9]+]]: index)
func.func @transpose_store_scalable_via_za_masked(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %a: index, %b: index) {
  // The user mask is carried over to the tile store. Its second bound is
  // clamped to the 2 source rows with index.mins.
  // CHECK: %[[C2:.*]] = arith.constant 2 : index
  // CHECK: %[[CLAMPED:.*]] = index.mins %[[B]], %[[C2]]
  // CHECK: %[[TILE_MASK:.*]] = vector.create_mask %[[A]], %[[CLAMPED]] : vector<[4]x[4]xi1>
  // CHECK: vector.transfer_write {{.*}} %[[TILE_MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
  %zero = arith.constant 0 : index
  %writeMask = vector.create_mask %a, %b : vector<[4]x2xi1>
  %transposed = vector.transpose %vec, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
  vector.transfer_write %transposed, %dest[%zero, %zero], %writeMask {in_bounds = [true, true]} : vector<[4]x2xf32>,  memref<?x?xf32>
  return
}
590
591// -----
592
// CHECK-LABEL: @transpose_store_scalable_via_za_multi_tile(
// CHECK-SAME:                                              %[[VEC:.*]]: vector<8x[4]xf32>
// CHECK-SAME:                                              %[[DEST:.*]]: memref<?x?xf32>,
// CHECK-SAME:                                              %[[I:.*]]: index,
// CHECK-SAME:                                              %[[J:.*]]: index)
func.func @transpose_store_scalable_via_za_multi_tile(%vec: vector<8x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
  // The 8 rows of %vec are split across two [4]x[4] tiles (rows 0-3 and 4-7).
  // Both tile stores share one mask. The second tile is stored 4 columns
  // further along the destination.
  // CHECK: %[[C4:.*]] = arith.constant 4 : index

  // <skip 3x other extract+insert chain>
  // CHECK: %[[V3:.*]] = vector.extract %[[VEC]][3] : vector<[4]xf32> from vector<8x[4]xf32>
  // CHECK: %[[TILE_0:.*]] = vector.insert %[[V3]], %{{.*}} [3] : vector<[4]xf32> into vector<[4]x[4]xf32>
  // CHECK: %[[VSCALE:.*]] = vector.vscale
  // CHECK: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
  // CHECK: %[[MASK:.*]] = vector.create_mask %[[C4_VSCALE]], %[[C4]] : vector<[4]x[4]xi1>
  // CHECK: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>

  // <skip 3x other extract+insert chain>
  // CHECK: %[[V7:.*]] = vector.extract %[[VEC]][7] : vector<[4]xf32> from vector<8x[4]xf32>
  // CHECK: %[[TILE_1:.*]] = vector.insert %[[V7]], %{{.*}} [3] : vector<[4]xf32> into vector<[4]x[4]xf32>
  // CHECK: %[[J_OFFSET:.*]] = arith.addi %[[J]], %[[C4]] : index
  // CHECK: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[I]], %[[J_OFFSET]]], %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
  %tr = vector.transpose %vec, [1, 0] : vector<8x[4]xf32> to vector<[4]x8xf32>
  vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x8xf32>,  memref<?x?xf32>
  return
}
618
619// -----
620
// CHECK-LABEL: @transpose_store_scalable_via_za_multi_tile_wide
func.func @transpose_store_scalable_via_za_multi_tile_wide(%vec: vector<2x[8]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
  // Each [8]-wide row of %vec is split into a lower and an upper [4] half.
  // The lower halves form the first tile, stored at [%i, %j]. The upper halves
  // form the second tile, stored (4 x vscale) rows further down.

  // <check extracts from lower 4 x vscale of %vec>
  // CHECK: vector.scalable.extract
  // CHECK: %[[ROW_2_LOWER:.*]] = vector.scalable.extract %{{.*}}[0] : vector<[4]xf32> from vector<[8]xf32>
  // CHECK: %[[TILE_0:.*]] = vector.insert %[[ROW_2_LOWER]], %{{.*}}[1] : vector<[4]xf32> into vector<[4]x[4]xf32>
  // CHECK: %[[C4_VSCALE:.*]] = arith.muli %{{.*}}, %{{.*}} : index
  // CHECK: vector.transfer_write %[[TILE_0]], %{{.*}}[%[[I:[a-z0-9]+]], %[[J:[a-z0-9]+]]]

  // <check extracts from upper 4 x vscale of %vec>
  // CHECK: vector.scalable.extract
  // CHECK: %[[ROW_2_UPPER:.*]] = vector.scalable.extract %{{.*}}[4] : vector<[4]xf32> from vector<[8]xf32>
  // CHECK: %[[TILE_1:.*]] = vector.insert %[[ROW_2_UPPER]], %{{.*}}[1] : vector<[4]xf32> into vector<[4]x[4]xf32>
  // CHECK: %[[I_OFFSET:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
  // CHECK: vector.transfer_write %[[TILE_1]], %{{.*}}[%[[I_OFFSET]], %[[J]]]
  %tr = vector.transpose %vec, [1, 0] : vector<2x[8]xf32> to vector<[8]x2xf32>
  vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[8]x2xf32>,  memref<?x?xf32>
  return
}
639
640// -----
641
// CHECK-LABEL: @negative_transpose_store_scalable_via_za__bad_source_shape
// CHECK-NOT: arm_sme.get_tile
func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%src: vector<2x[7]xf32>, %dest: memref<?x?xf32>, %row: index, %col: index) {
  // Negative case: the [7] scalable dimension does not split evenly into
  // [4]-element f32 tile slices, so no SME tile should be created.
  %transposed = vector.transpose %src, [1, 0] : vector<2x[7]xf32> to vector<[7]x2xf32>
  vector.transfer_write %transposed, %dest[%row, %col] {in_bounds = [true, true]} : vector<[7]x2xf32>,  memref<?x?xf32>
  return
}
649
650// -----
651
// From: https://github.com/llvm/llvm-project/issues/118449.
// Checks that -arm-sme-vector-legalization does not crash when it encounters a
// `vector.mask` whose body contains no maskable op (only a `vector.yield`).
// The CHECK-LABEL anchors this split so the function must actually be emitted.
// CHECK-LABEL: @vector_mask_without_maskable_op
func.func @vector_mask_without_maskable_op(%mask: vector<16x2xi1>, %vec: vector<16x16xf32>) -> vector<16x16xf32> {
  %0 = vector.mask %mask { vector.yield %vec : vector<16x16xf32> } : vector<16x2xi1> -> vector<16x16xf32>
  return %0 : vector<16x16xf32>
}
659