xref: /llvm-project/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir (revision e4384149b58f7c3d19c5d38bc46038c660b77ca9)
1// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
2
/// Tests that shape casts of scalable vectors (with a single, trailing scalable
/// dim) are correctly lowered to vector.extract/insert and
/// vector.scalable.extract/insert.
5
/// Flatten 2x1x[4] -> [8]: each [4]xi32 row is pulled out with vector.extract
/// and placed into the flat result with vector.scalable.insert at offset 4*row.
// CHECK-LABEL: @i32_3d_to_1d_last_dim_scalable
// CHECK-SAME:    %[[SRC:.*]]: vector<2x1x[4]xi32>
func.func @i32_3d_to_1d_last_dim_scalable(%src: vector<2x1x[4]xi32>) -> vector<[8]xi32> {
  // CHECK-NEXT: %[[ZERO:.*]] = arith.constant dense<0> : vector<[8]xi32>
  // CHECK-NEXT: %[[ROW0:.*]] = vector.extract %[[SRC]][0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
  // CHECK-NEXT: %[[ACC0:.*]] = vector.scalable.insert %[[ROW0]], %[[ZERO]][0] : vector<[4]xi32> into vector<[8]xi32>
  // CHECK-NEXT: %[[ROW1:.*]] = vector.extract %[[SRC]][1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
  // CHECK-NEXT: %[[ACC1:.*]] = vector.scalable.insert %[[ROW1]], %[[ACC0]][4] : vector<[4]xi32> into vector<[8]xi32>
  // CHECK-NEXT: return %[[ACC1]] : vector<[8]xi32>
  %0 = vector.shape_cast %src : vector<2x1x[4]xi32> to vector<[8]xi32>
  return %0 : vector<[8]xi32>
}
19
20// -----
21
/// Unflatten [8] -> 2x1x[4]: each [4]xi32 chunk is taken with
/// vector.scalable.extract at offset 4*row and written with vector.insert.
// CHECK-LABEL: @i32_1d_to_3d_last_dim_scalable
// CHECK-SAME:    %[[SRC:.*]]: vector<[8]xi32>
func.func @i32_1d_to_3d_last_dim_scalable(%src: vector<[8]xi32>) -> vector<2x1x[4]xi32> {
  // CHECK-NEXT: %[[ZERO:.*]] = arith.constant dense<0> : vector<2x1x[4]xi32>
  // CHECK-NEXT: %[[CHUNK0:.*]] = vector.scalable.extract %[[SRC]][0] : vector<[4]xi32> from vector<[8]xi32>
  // CHECK-NEXT: %[[ACC0:.*]] = vector.insert %[[CHUNK0]], %[[ZERO]] [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
  // CHECK-NEXT: %[[CHUNK1:.*]] = vector.scalable.extract %[[SRC]][4] : vector<[4]xi32> from vector<[8]xi32>
  // CHECK-NEXT: %[[ACC1:.*]] = vector.insert %[[CHUNK1]], %[[ACC0]] [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
  // CHECK-NEXT: return %[[ACC1]] : vector<2x1x[4]xi32>
  %0 = vector.shape_cast %src : vector<[8]xi32> to vector<2x1x[4]xi32>
  return %0 : vector<2x1x[4]xi32>
}
34
35// -----
36
/// Flatten 4x[8] -> [32]: rows are extracted one at a time and chained through
/// vector.scalable.insert at offsets 0, 8, 16 and 24.
// CHECK-LABEL: @i8_2d_to_1d_last_dim_scalable
// CHECK-SAME:    %[[SRC:.*]]: vector<4x[8]xi8>
func.func @i8_2d_to_1d_last_dim_scalable(%src: vector<4x[8]xi8>) -> vector<[32]xi8> {
  // CHECK-NEXT: %[[ZERO:.*]] = arith.constant dense<0> : vector<[32]xi8>
  // CHECK-NEXT: %[[ROW0:.*]] = vector.extract %[[SRC]][0] : vector<[8]xi8> from vector<4x[8]xi8>
  // CHECK-NEXT: %[[ACC0:.*]] = vector.scalable.insert %[[ROW0]], %[[ZERO]][0] : vector<[8]xi8> into vector<[32]xi8>
  // CHECK-NEXT: %[[ROW1:.*]] = vector.extract %[[SRC]][1] : vector<[8]xi8> from vector<4x[8]xi8>
  // CHECK-NEXT: %[[ACC1:.*]] = vector.scalable.insert %[[ROW1]], %[[ACC0]][8] : vector<[8]xi8> into vector<[32]xi8>
  // CHECK-NEXT: %[[ROW2:.*]] = vector.extract %[[SRC]][2] : vector<[8]xi8> from vector<4x[8]xi8>
  // CHECK-NEXT: %[[ACC2:.*]] = vector.scalable.insert %[[ROW2]], %[[ACC1]][16] : vector<[8]xi8> into vector<[32]xi8>
  // CHECK-NEXT: %[[ROW3:.*]] = vector.extract %[[SRC]][3] : vector<[8]xi8> from vector<4x[8]xi8>
  // CHECK-NEXT: %[[ACC3:.*]] = vector.scalable.insert %[[ROW3]], %[[ACC2]][24] : vector<[8]xi8> into vector<[32]xi8>
  // CHECK-NEXT: return %[[ACC3]] : vector<[32]xi8>
  %0 = vector.shape_cast %src : vector<4x[8]xi8> to vector<[32]xi8>
  return %0 : vector<[32]xi8>
}
53
54// -----
55
/// Unflatten [32] -> 4x[8]: [8]xi8 chunks at offsets 0, 8, 16 and 24 are
/// taken with vector.scalable.extract and inserted as rows 0..3.
// CHECK-LABEL: @i8_1d_to_2d_last_dim_scalable
// CHECK-SAME:    %[[SRC:.*]]: vector<[32]xi8>
func.func @i8_1d_to_2d_last_dim_scalable(%src: vector<[32]xi8>) -> vector<4x[8]xi8> {
  // CHECK-NEXT: %[[ZERO:.*]] = arith.constant dense<0> : vector<4x[8]xi8>
  // CHECK-NEXT: %[[CHUNK0:.*]] = vector.scalable.extract %[[SRC]][0] : vector<[8]xi8> from vector<[32]xi8>
  // CHECK-NEXT: %[[ACC0:.*]] = vector.insert %[[CHUNK0]], %[[ZERO]] [0] : vector<[8]xi8> into vector<4x[8]xi8>
  // CHECK-NEXT: %[[CHUNK1:.*]] = vector.scalable.extract %[[SRC]][8] : vector<[8]xi8> from vector<[32]xi8>
  // CHECK-NEXT: %[[ACC1:.*]] = vector.insert %[[CHUNK1]], %[[ACC0]] [1] : vector<[8]xi8> into vector<4x[8]xi8>
  // CHECK-NEXT: %[[CHUNK2:.*]] = vector.scalable.extract %[[SRC]][16] : vector<[8]xi8> from vector<[32]xi8>
  // CHECK-NEXT: %[[ACC2:.*]] = vector.insert %[[CHUNK2]], %[[ACC1]] [2] : vector<[8]xi8> into vector<4x[8]xi8>
  // CHECK-NEXT: %[[CHUNK3:.*]] = vector.scalable.extract %[[SRC]][24] : vector<[8]xi8> from vector<[32]xi8>
  // CHECK-NEXT: %[[ACC3:.*]] = vector.insert %[[CHUNK3]], %[[ACC2]] [3] : vector<[8]xi8> into vector<4x[8]xi8>
  // CHECK-NEXT: return %[[ACC3]] : vector<4x[8]xi8>
  %0 = vector.shape_cast %src : vector<[32]xi8> to vector<4x[8]xi8>
  return %0 : vector<4x[8]xi8>
}
72
73// -----
74
/// Reshape 2x3x[4] -> 3x2x[4]: only the leading fixed-size dims change, so
/// each of the six [4]xf32 rows is moved with a vector.extract/vector.insert
/// pair. Despite the test name, the index pairs below walk both shapes in
/// row-major order (a reshape), not a transpose.
// CHECK-LABEL: @f32_permute_leading_non_scalable_dims
// CHECK-SAME:    %[[SRC:.*]]: vector<2x3x[4]xf32>
func.func @f32_permute_leading_non_scalable_dims(%src: vector<2x3x[4]xf32>) -> vector<3x2x[4]xf32> {
  // CHECK-NEXT: %[[ZERO:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x[4]xf32>
  // CHECK-NEXT: %[[ROW0:.*]] = vector.extract %[[SRC]][0, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
  // CHECK-NEXT: %[[ACC0:.*]] = vector.insert %[[ROW0]], %[[ZERO]] [0, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
  // CHECK-NEXT: %[[ROW1:.*]] = vector.extract %[[SRC]][0, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
  // CHECK-NEXT: %[[ACC1:.*]] = vector.insert %[[ROW1]], %[[ACC0]] [0, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
  // CHECK-NEXT: %[[ROW2:.*]] = vector.extract %[[SRC]][0, 2] : vector<[4]xf32> from vector<2x3x[4]xf32>
  // CHECK-NEXT: %[[ACC2:.*]] = vector.insert %[[ROW2]], %[[ACC1]] [1, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
  // CHECK-NEXT: %[[ROW3:.*]] = vector.extract %[[SRC]][1, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
  // CHECK-NEXT: %[[ACC3:.*]] = vector.insert %[[ROW3]], %[[ACC2]] [1, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
  // CHECK-NEXT: %[[ROW4:.*]] = vector.extract %[[SRC]][1, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
  // CHECK-NEXT: %[[ACC4:.*]] = vector.insert %[[ROW4]], %[[ACC3]] [2, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
  // CHECK-NEXT: %[[ROW5:.*]] = vector.extract %[[SRC]][1, 2] : vector<[4]xf32> from vector<2x3x[4]xf32>
  // CHECK-NEXT: %[[ACC5:.*]] = vector.insert %[[ROW5]], %[[ACC4]] [2, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
  // CHECK-NEXT: return %[[ACC5]] : vector<3x2x[4]xf32>
  %0 = vector.shape_cast %src : vector<2x3x[4]xf32> to vector<3x2x[4]xf32>
  return %0 : vector<3x2x[4]xf32>
}
95
96// -----
97
/// Flatten the two leading fixed-size dims (2x2 -> 4) while keeping the
/// trailing scalable dim unchanged: each [2]xf64 row is moved with a
/// vector.extract/vector.insert pair, in row-major order.
// CHECK-LABEL: f64_flatten_leading_non_scalable_dims
// CHECK-SAME: %[[arg0:.*]]: vector<2x2x[2]xf64>
func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) -> vector<4x[2]xf64>
{
  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4x[2]xf64>
  // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[2]xf64> from vector<2x2x[2]xf64>
  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[2]xf64> into vector<4x[2]xf64>
  // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[2]xf64> from vector<2x2x[2]xf64>
  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1] : vector<[2]xf64> into vector<4x[2]xf64>
  // CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][1, 0] : vector<[2]xf64> from vector<2x2x[2]xf64>
  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [2] : vector<[2]xf64> into vector<4x[2]xf64>
  // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 1] : vector<[2]xf64> from vector<2x2x[2]xf64>
  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[2]xf64> into vector<4x[2]xf64>
  %res = vector.shape_cast %arg0: vector<2x2x[2]xf64> to vector<4x[2]xf64>
  // Refer to the last insert through its captured name instead of a
  // hard-coded SSA number, so the check survives printer renumbering.
  // CHECK-NEXT: return %[[res3]] : vector<4x[2]xf64>
  return %res : vector<4x[2]xf64>
}
115
116// -----
117
/// Reshape 3x[4] -> 6x[2] (shrink the trailing scalable dim).
/// Each [4]xf32 source row is extracted once and split into two halves with
/// vector.scalable.extract at offsets 0 and 2. Each half is then inserted as
/// one row of the result.
// CHECK-LABEL: @f32_reduce_trailing_scalable_dim
// CHECK-SAME:    %[[SRC:.*]]: vector<3x[4]xf32>
func.func @f32_reduce_trailing_scalable_dim(%src: vector<3x[4]xf32>) -> vector<6x[2]xf32> {
  // CHECK-NEXT: %[[ZERO:.*]] = arith.constant dense<0.000000e+00> : vector<6x[2]xf32>
  // CHECK-NEXT: %[[IN0:.*]] = vector.extract %[[SRC]][0] : vector<[4]xf32> from vector<3x[4]xf32>
  // CHECK-NEXT: %[[HALF0:.*]] = vector.scalable.extract %[[IN0]][0] : vector<[2]xf32> from vector<[4]xf32>
  // CHECK-NEXT: %[[ACC0:.*]] = vector.insert %[[HALF0]], %[[ZERO]] [0] : vector<[2]xf32> into vector<6x[2]xf32>
  // CHECK-NEXT: %[[HALF1:.*]] = vector.scalable.extract %[[IN0]][2] : vector<[2]xf32> from vector<[4]xf32>
  // CHECK-NEXT: %[[ACC1:.*]] = vector.insert %[[HALF1]], %[[ACC0]] [1] : vector<[2]xf32> into vector<6x[2]xf32>
  // CHECK-NEXT: %[[IN1:.*]] = vector.extract %[[SRC]][1] : vector<[4]xf32> from vector<3x[4]xf32>
  // CHECK-NEXT: %[[HALF2:.*]] = vector.scalable.extract %[[IN1]][0] : vector<[2]xf32> from vector<[4]xf32>
  // CHECK-NEXT: %[[ACC2:.*]] = vector.insert %[[HALF2]], %[[ACC1]] [2] : vector<[2]xf32> into vector<6x[2]xf32>
  // CHECK-NEXT: %[[HALF3:.*]] = vector.scalable.extract %[[IN1]][2] : vector<[2]xf32> from vector<[4]xf32>
  // CHECK-NEXT: %[[ACC3:.*]] = vector.insert %[[HALF3]], %[[ACC2]] [3] : vector<[2]xf32> into vector<6x[2]xf32>
  // CHECK-NEXT: %[[IN2:.*]] = vector.extract %[[SRC]][2] : vector<[4]xf32> from vector<3x[4]xf32>
  // CHECK-NEXT: %[[HALF4:.*]] = vector.scalable.extract %[[IN2]][0] : vector<[2]xf32> from vector<[4]xf32>
  // CHECK-NEXT: %[[ACC4:.*]] = vector.insert %[[HALF4]], %[[ACC3]] [4] : vector<[2]xf32> into vector<6x[2]xf32>
  // CHECK-NEXT: %[[HALF5:.*]] = vector.scalable.extract %[[IN2]][2] : vector<[2]xf32> from vector<[4]xf32>
  // CHECK-NEXT: %[[ACC5:.*]] = vector.insert %[[HALF5]], %[[ACC4]] [5] : vector<[2]xf32> into vector<6x[2]xf32>
  // CHECK-NEXT: return %[[ACC5]] : vector<6x[2]xf32>
  %0 = vector.shape_cast %src : vector<3x[4]xf32> to vector<6x[2]xf32>
  return %0 : vector<6x[2]xf32>
}
142
143// -----
144
/// Reshape 4x[2] -> 2x[4] (grow the trailing scalable dim).
/// For each result row, the matching row of the zero constant is extracted.
/// Two [2]xf32 source rows are placed into it with vector.scalable.insert at
/// offsets 0 and 2. The completed [4]xf32 row is then inserted back into the
/// result.
// CHECK-LABEL: @f32_increase_trailing_scalable_dim
// CHECK-SAME:    %[[SRC:.*]]: vector<4x[2]xf32>
func.func @f32_increase_trailing_scalable_dim(%src: vector<4x[2]xf32>) -> vector<2x[4]xf32> {
  // CHECK-NEXT: %[[ZERO:.*]] = arith.constant dense<0.000000e+00> : vector<2x[4]xf32>
  // CHECK-NEXT: %[[IN0:.*]] = vector.extract %[[SRC]][0] : vector<[2]xf32> from vector<4x[2]xf32>
  // CHECK-NEXT: %[[OUTROW0:.*]] = vector.extract %[[ZERO]][0] : vector<[4]xf32> from vector<2x[4]xf32>
  // CHECK-NEXT: %[[OUTROW0_A:.*]] = vector.scalable.insert %[[IN0]], %[[OUTROW0]][0] : vector<[2]xf32> into vector<[4]xf32>
  // CHECK-NEXT: %[[IN1:.*]] = vector.extract %[[SRC]][1] : vector<[2]xf32> from vector<4x[2]xf32>
  // CHECK-NEXT: %[[OUTROW0_B:.*]] = vector.scalable.insert %[[IN1]], %[[OUTROW0_A]][2] : vector<[2]xf32> into vector<[4]xf32>
  // CHECK-NEXT: %[[ACC0:.*]] = vector.insert %[[OUTROW0_B]], %[[ZERO]] [0] : vector<[4]xf32> into vector<2x[4]xf32>
  // CHECK-NEXT: %[[IN2:.*]] = vector.extract %[[SRC]][2] : vector<[2]xf32> from vector<4x[2]xf32>
  // CHECK-NEXT: %[[OUTROW1:.*]] = vector.extract %[[ZERO]][1] : vector<[4]xf32> from vector<2x[4]xf32>
  // CHECK-NEXT: %[[OUTROW1_A:.*]] = vector.scalable.insert %[[IN2]], %[[OUTROW1]][0] : vector<[2]xf32> into vector<[4]xf32>
  // CHECK-NEXT: %[[IN3:.*]] = vector.extract %[[SRC]][3] : vector<[2]xf32> from vector<4x[2]xf32>
  // CHECK-NEXT: %[[OUTROW1_B:.*]] = vector.scalable.insert %[[IN3]], %[[OUTROW1_A]][2] : vector<[2]xf32> into vector<[4]xf32>
  // CHECK-NEXT: %[[ACC1:.*]] = vector.insert %[[OUTROW1_B]], %[[ACC0]] [1] : vector<[4]xf32> into vector<2x[4]xf32>
  // CHECK-NEXT: return %[[ACC1]] : vector<2x[4]xf32>
  %0 = vector.shape_cast %src : vector<4x[2]xf32> to vector<2x[4]xf32>
  return %0 : vector<2x[4]xf32>
}
166
167// -----
168
/// The following shape_casts are not supported: the types cannot be
/// represented in LLVM (and likely won't be supported soon), and there are
/// currently no ops that could perform the required extracts/inserts. The
/// lowering is expected to leave these shape_casts untouched.
172
173// -----
174
/// The result's scalable dim is not trailing ([2]x2), so the pattern must
/// leave the shape_cast as-is.
// CHECK-LABEL: @cannot_cast_to_non_trailing_scalable_dim
// CHECK-SAME:    %[[SRC:.*]]: vector<[4]xf32>
func.func @cannot_cast_to_non_trailing_scalable_dim(%src: vector<[4]xf32>) -> vector<[2]x2xf32> {
  // CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[SRC]] : vector<[4]xf32> to vector<[2]x2xf32>
  // CHECK-NEXT: return %[[CAST]] : vector<[2]x2xf32>
  %0 = vector.shape_cast %src : vector<[4]xf32> to vector<[2]x2xf32>
  return %0 : vector<[2]x2xf32>
}
183
184// -----
185
/// The source's scalable dim is not trailing ([2]x2), so the pattern must
/// leave the shape_cast as-is.
// CHECK-LABEL: @cannot_shape_cast_from_non_trailing_scalable_dim
// CHECK-SAME:    %[[SRC:.*]]: vector<[2]x2xf32>
func.func @cannot_shape_cast_from_non_trailing_scalable_dim(%src: vector<[2]x2xf32>) -> vector<[4]xf32> {
  // CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[SRC]] : vector<[2]x2xf32> to vector<[4]xf32>
  // CHECK-NEXT: return %[[CAST]] : vector<[4]xf32>
  %0 = vector.shape_cast %src : vector<[2]x2xf32> to vector<[4]xf32>
  return %0 : vector<[4]xf32>
}
194
195// -----
196
/// Both source and result carry more than one scalable dim, so the pattern
/// must leave the shape_cast as-is.
// CHECK-LABEL: @cannot_shape_cast_more_than_one_scalable_dim
// CHECK-SAME:    %[[SRC:.*]]: vector<[4]x[4]xf32>
func.func @cannot_shape_cast_more_than_one_scalable_dim(%src: vector<[4]x[4]xf32>) -> vector<2x[2]x[4]xf32> {
  // CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[SRC]] : vector<[4]x[4]xf32> to vector<2x[2]x[4]xf32>
  // CHECK-NEXT: return %[[CAST]] : vector<2x[2]x[4]xf32>
  %0 = vector.shape_cast %src : vector<[4]x[4]xf32> to vector<2x[2]x[4]xf32>
  return %0 : vector<2x[2]x[4]xf32>
}
205
/// Transform script driving this test: it matches every func.func in the
/// payload and applies the vector shape_cast lowering patterns to them.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %funcs {
      transform.apply_patterns.vector.lower_shape_cast
    } : !transform.any_op
    transform.yield
  }
}
217