xref: /llvm-project/mlir/test/Conversion/VectorToArmSME/unsupported.mlir (revision e2296d8295516e9991cd6ca99ba193fbd232b6da)
1// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
2
3//===----------------------------------------------------------------------===//
4// vector.transfer_read
5//===----------------------------------------------------------------------===//
6
/// Transposed 2-D read whose result type is vector<[4]x[4]xf64>; the other f64
/// reads in this file use vector<[2]x[2]xf64>. The read is expected to be left
/// in place. `in_bounds` is all-true so that the vector type is the only
/// difference from those cases and this test does not overlap with
/// @transfer_read_2d__out_of_bounds.

// CHECK-LABEL: @transfer_read_2d__bad_type
// CHECK-NOT: arm_sme.tile_load
// CHECK: vector.transfer_read
func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f64
  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[4]x[4]xf64>
  // Unregistered user op that keeps the read alive.
  "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
  return
}
17
18// -----
19
/// Transposed, in-bounds 2-D read whose source is a tensor rather than a
/// memref. The read is expected to be left in place.

// CHECK-LABEL: @transfer_read_2d__non_memref_type
// CHECK-NOT: arm_sme.tile_load
// CHECK: vector.transfer_read
func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
  %origin = arith.constant 0 : index
  %padding = arith.constant 0.0 : f64
  %tile = vector.transfer_read %src[%origin, %origin], %padding
    {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
    : tensor<?x?xf64>, vector<[2]x[2]xf64>
  // Unregistered user op that keeps the read alive.
  "prevent.dce"(%tile) : (vector<[2]x[2]xf64>) -> ()
  return
}
30
31// -----
32
/// Read of a 1-D vector<[2]xf64> from a 2-D memref (the permutation map keeps
/// only d0). The read is expected to be left in place.

// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
// CHECK-NOT: arm_sme.tile_load
// CHECK: vector.transfer_read
func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
  %origin = arith.constant 0 : index
  %padding = arith.constant 0.0 : f64
  %row = vector.transfer_read %src[%origin, %origin], %padding
    {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>}
    : memref<?x?xf64>, vector<[2]xf64>
  // Unregistered user op that keeps the read alive.
  "prevent.dce"(%row) : (vector<[2]xf64>) -> ()
  return
}
43
44// -----
45
/// In-bounds 2-D read whose permutation map is (d0, d1) -> (d0, 0) instead of
/// the (d1, d0) transpose used by the neighbouring cases. The read is expected
/// to be left in place.

// CHECK-LABEL: @transfer_read_2d__non_transpose
// CHECK-NOT: arm_sme.tile_load
// CHECK: vector.transfer_read
func.func @transfer_read_2d__non_transpose(%src : memref<?x?xf64>) {
  %origin = arith.constant 0 : index
  %padding = arith.constant 0.0 : f64
  %tile = vector.transfer_read %src[%origin, %origin], %padding
    {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d0, 0)>}
    : memref<?x?xf64>, vector<[2]x[2]xf64>
  // Unregistered user op that keeps the read alive.
  "prevent.dce"(%tile) : (vector<[2]x[2]xf64>) -> ()
  return
}
56
57// -----
58
/// Transposed 2-D read of vector<[2]x[2]xf64> with both dimensions marked as
/// not in-bounds. The read is expected to be left in place.

// CHECK-LABEL: @transfer_read_2d__out_of_bounds
// CHECK-NOT: arm_sme.tile_load
// CHECK: vector.transfer_read
func.func @transfer_read_2d__out_of_bounds(%src : memref<?x?xf64>) {
  %origin = arith.constant 0 : index
  %padding = arith.constant 0.0 : f64
  %tile = vector.transfer_read %src[%origin, %origin], %padding
    {in_bounds = [false, false], permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
    : memref<?x?xf64>, vector<[2]x[2]xf64>
  // Unregistered user op that keeps the read alive.
  "prevent.dce"(%tile) : (vector<[2]x[2]xf64>) -> ()
  return
}
69
70//===----------------------------------------------------------------------===//
71// vector.transfer_write
72//===----------------------------------------------------------------------===//
73
74// -----
75
// The following tests check that the 'vector.transfer_write' ->
// 'arm_sme.tile_store' lowering only occurs for vector types of the correct
// rank, shape, element size and number of scalable dims.
79
/// In-bounds 2-D write of a zero constant with element type i4. The write is
/// expected to be left in place. The negative pattern names
/// arm_sme.tile_store, matching the other transfer_write cases in this file.

// CHECK-LABEL: @transfer_write_2d_zero__bad_type
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d_zero__bad_type(%arg0 : memref<?x?xi4>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant dense<0> : vector<[16]x[16]xi4>
  vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi4>, memref<?x?xi4>
  return
}
89
90// -----
91
/// In-bounds 2-D write of a zero constant of type vector<[8]x[8]xi8>; the
/// neighbouring i8 cases use [16]x[16]. The write is expected to be left in
/// place.

// CHECK-LABEL: @transfer_write_2d_zero__bad_shape
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d_zero__bad_shape(%arg0 : memref<?x?xi8>) {
  %origin = arith.constant 0 : index
  %zeros = arith.constant dense<0> : vector<[8]x[8]xi8>
  vector.transfer_write %zeros, %arg0[%origin, %origin]
    {in_bounds = [true, true]}
    : vector<[8]x[8]xi8>, memref<?x?xi8>
  return
}
101
102// -----
103
/// In-bounds write of a 3-D zero constant, vector<[16]x[16]x[16]xi8>, to a 3-D
/// memref. The write is expected to be left in place.

// CHECK-LABEL: @transfer_write_2d_zero__bad_rank
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d_zero__bad_rank(%arg0 : memref<?x?x?xi8>) {
  %origin = arith.constant 0 : index
  %zeros = arith.constant dense<0> : vector<[16]x[16]x[16]xi8>
  vector.transfer_write %zeros, %arg0[%origin, %origin, %origin]
    {in_bounds = [true, true, true]}
    : vector<[16]x[16]x[16]xi8>, memref<?x?x?xi8>
  return
}
113
114// -----
115
/// In-bounds 2-D write of a zero constant into a tensor rather than a memref;
/// the updated tensor is returned. The write is expected to be left in place.

// CHECK-LABEL: @transfer_write_2d_zero__non_memref_type
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> tensor<?x?xi8> {
  %origin = arith.constant 0 : index
  %zeros = arith.constant dense<0> : vector<[16]x[16]xi8>
  %updated = vector.transfer_write %zeros, %arg0[%origin, %origin]
    {in_bounds = [true, true]}
    : vector<[16]x[16]xi8>, tensor<?x?xi8>
  return %updated : tensor<?x?xi8>
}
125
126// -----
127
/// In-bounds 2-D write of a fixed-size vector<16x16xi8> (no scalable
/// dimensions). The write is expected to be left in place.

// CHECK-LABEL: @transfer_write_2d__fixed
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?x?xi8>) {
  %origin = arith.constant 0 : index
  vector.transfer_write %vector, %dest[%origin, %origin]
    {in_bounds = [true, true]}
    : vector<16x16xi8>, memref<?x?xi8>
  return
}
136
137// -----
138
/// 2-D write of vector<[4]x[4]xf32> with both dimensions marked as not
/// in-bounds. The write is expected to be left in place.

// CHECK-LABEL: @transfer_write_2d__out_of_bounds
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
  %origin = arith.constant 0 : index
  vector.transfer_write %vector, %dest[%origin, %origin]
    {in_bounds = [false, false]}
    : vector<[4]x[4]xf32>, memref<?x?xf32>
  return
}
147
148// -----
149
/// Write of a single extracted slice, vector<[4]xf32>, using the permutation
/// map (d0, d1) -> (d0). No arm_sme.store_tile_slice is expected.

// CHECK-LABEL: func.func @transfer_write_slice_unsupported_permutation
// CHECK-NOT: arm_sme.store_tile_slice
func.func @transfer_write_slice_unsupported_permutation(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %slice_index: index) {
  %origin = arith.constant 0 : index
  %row = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
  vector.transfer_write %row, %dest[%slice_index, %origin]
    {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>}
    : vector<[4]xf32>, memref<?x?xf32>
  return
}
158
159
160//===----------------------------------------------------------------------===//
161// vector.outerproduct
162//===----------------------------------------------------------------------===//
163
164// -----
165
/// Outer product whose right-hand side is a scalar f64 instead of a vector.
/// The op is expected to be left in place. The combining kind is `add` so that
/// the scalar operand is the only unusual feature and this case does not
/// overlap with @vector_outerproduct_unsupported_kind, which uses `mul`.

// CHECK-LABEL: @vector_outerproduct_unsupported_axpy
// CHECK-NOT: arm_sme.outerproduct
// CHECK:     vector.outerproduct
func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, f64
  return %0 : vector<[2]xf64>
}
173
174// -----
175
/// Outer product of two vector<[2]xf64> operands with combining kind `mul`.
/// The op is expected to be left in place.

// CHECK-LABEL: @vector_outerproduct_unsupported_kind
// CHECK-NOT: arm_sme.outerproduct
// CHECK:     vector.outerproduct
func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
  %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
  // Unregistered user op that keeps the outer product alive.
  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
  // Explicit terminator, as in the other void test functions in this file.
  return
}
184
185// -----
186
/// Masked outer product (kind `add`) whose mask is a plain function argument,
/// so the operation producing the mask is not visible inside the function.
/// The op is expected to be left in place.

// CHECK-LABEL: @vector_outerproduct_unknown_mask
// CHECK-NOT: arm_sme.outerproduct
// CHECK:     vector.outerproduct
func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %mask : vector<[4]x[4]xi1>) {
  %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
  %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
  // Unregistered user op that keeps the masked outer product alive.
  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
  // Explicit terminator, as in the other void test functions in this file.
  return
}
195
196// -----
197
/// The extracted slice, vector<[32]xi1>, is not SVE predicate-sized.

// CHECK-LABEL: @negative_vector_extract_to_psel_0
func.func @negative_vector_extract_to_psel_0(%a: index, %b: index, %index: index) -> vector<[32]xi1> {
  // CHECK-NOT: arm_sve.psel
  %full_mask = vector.create_mask %a, %b : vector<[4]x[32]xi1>
  %row = vector.extract %full_mask[%index] : vector<[32]xi1> from vector<[4]x[32]xi1>
  return %row : vector<[32]xi1>
}
208
209// -----
210
/// The source mask, vector<4x[8]xi1>, is not a 2-D scalable mask: its leading
/// dimension is fixed.

// CHECK-LABEL: @negative_vector_extract_to_psel_1
func.func @negative_vector_extract_to_psel_1(%a: index, %b: index, %index: index) -> vector<[8]xi1> {
  // CHECK-NOT: arm_sve.psel
  %full_mask = vector.create_mask %a, %b : vector<4x[8]xi1>
  %row = vector.extract %full_mask[%index] : vector<[8]xi1> from vector<4x[8]xi1>
  return %row : vector<[8]xi1>
}
221
222// -----
223
/// The source mask is a function argument, not the result of a
/// vector.create_mask.

// CHECK-LABEL: @negative_vector_extract_to_psel_2
func.func @negative_vector_extract_to_psel_2(%mask: vector<[4]x[8]xi1>, %index: index) -> vector<[8]xi1> {
  // CHECK-NOT: arm_sve.psel
  %row = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1>
  return %row : vector<[8]xi1>
}
233
234// -----
235
/// Not a psel-like extract: a single i1 element is extracted at position
/// [2, %index] rather than a whole 1-D slice.

// CHECK-LABEL: @negative_vector_extract_to_psel_3
func.func @negative_vector_extract_to_psel_3(%a: index, %b: index, %index: index) -> i1 {
  // CHECK-NOT: arm_sve.psel
  %full_mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
  %bit = vector.extract %full_mask[2, %index] : i1 from vector<[4]x[8]xi1>
  return %bit : i1
}
246