xref: /llvm-project/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (revision a7a4c16c672bdd8e245af533a1f170522e26e42a)
1// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
2
3// CHECK-LABEL: func @nop_shape_cast
4// CHECK-SAME: %[[A:.*]]: vector<16xf32>
5// CHECK:      return %[[A]] : vector<16xf32>
6func.func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
7  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32>
8  return %0 : vector<16xf32>
9}
10
11// CHECK-LABEL: func @cancel_shape_cast
12// CHECK-SAME: %[[A:.*]]: vector<16xf32>
13// CHECK:      return %[[A]] : vector<16xf32>
14
15func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
16  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
17  %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
18  return %1 : vector<16xf32>
19}
20
21// Shape up and downcasts for 2-D vectors, for supporting conversion to
22// llvm.matrix operations
23// CHECK-LABEL: func @shape_casts
24func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
25  // CHECK-DAG: %[[cst22:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
26  // CHECK-DAG: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
27  // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2xf32> from vector<2x2xf32>
28  //
29  // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]]
30  // CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
31  //
32  // CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2xf32> from vector<2x2xf32>
33  //
34  // CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]]
35  // CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
36  //
37  %0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32>
38  // CHECK: %[[add:.*]] = arith.addf %[[in2]], %[[in2]] : vector<4xf32>
39  %r0 = arith.addf %0, %0: vector<4xf32>
40  //
41  // CHECK: %[[ss0:.*]] = vector.extract_strided_slice %[[add]]
42  // CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} :
43  // CHECK-SAME: vector<4xf32> to vector<2xf32>
44  //
45  // CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[cst22]] [0] :
46  // CHECK-SAME: vector<2xf32> into vector<2x2xf32>
47  //
48  // CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[add]]
49  // CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} :
50  // CHECK-SAME: vector<4xf32> to vector<2xf32>
51  //
52  // CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] :
53  // CHECK-SAME: vector<2xf32> into vector<2x2xf32>
54  //
55  %1 = vector.shape_cast %r0  : vector<4xf32> to vector<2x2xf32>
56  // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32>
57  return %r0, %1 : vector<4xf32>, vector<2x2xf32>
58}
59
60// CHECK-LABEL: func @shape_cast_2d2d
61// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
62// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
63// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32>
64// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32>
65// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : f32 from vector<3x2xf32>
66// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32>
67// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<3x2xf32>
68// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32>
69// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : f32 from vector<3x2xf32>
70// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32>
71// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<3x2xf32>
72// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32>
73// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : f32 from vector<3x2xf32>
74// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32>
75// CHECK: return %[[T11]] : vector<2x3xf32>
76
77func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
78  %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32>
79  return %s : vector<2x3xf32>
80}
81
82// CHECK-LABEL: func @shape_cast_3d1d
83// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
84// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32>
85// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
86// CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[C]]
87// CHECK-SAME:           {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
88// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2xf32> from vector<1x3x2xf32>
89// CHECK: %[[T3:.*]] = vector.insert_strided_slice %[[T2]], %[[T1]]
90// CHECK-SAME:           {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32>
91// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2xf32> from vector<1x3x2xf32>
92// CHECK: %[[T5:.*]] = vector.insert_strided_slice %[[T4]], %[[T3]]
93// CHECK-SAME:           {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32>
94// CHECK: return %[[T5]] : vector<6xf32>
95
96func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
97  %s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
98  return %s : vector<6xf32>
99}
100
101// CHECK-LABEL: func @shape_cast_1d3d
102// CHECK-SAME: %[[A:.*]]: vector<6xf32>
103// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32>
104// CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
105// CHECK-SAME:           {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
106// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : vector<3xf32> into vector<2x1x3xf32>
107// CHECK: %[[T2:.*]] = vector.extract_strided_slice %[[A]]
108// CHECK:                {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
109// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : vector<3xf32> into vector<2x1x3xf32>
110// CHECK: return %[[T3]] : vector<2x1x3xf32>
111
112func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
113  %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
114  return %s : vector<2x1x3xf32>
115}
116
117// CHECK-LABEL:   func.func @shape_cast_0d1d(
118// CHECK-SAME:                               %[[VAL_0:.*]]: vector<f32>) -> vector<1xf32> {
119// CHECK:           %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
120// CHECK:           %[[VAL_2:.*]] = vector.extractelement %[[VAL_0]][] : vector<f32>
121// CHECK:           %[[VAL_3:.*]] = vector.insert %[[VAL_2]], %[[VAL_1]] [0] : f32 into vector<1xf32>
122// CHECK:           return %[[VAL_3]] : vector<1xf32>
123// CHECK:         }
124
125func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
126  %s = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
127  return %s : vector<1xf32>
128}
129
130// CHECK-LABEL:   func.func @shape_cast_1d0d(
131// CHECK-SAME:                               %[[VAL_0:.*]]: vector<1xf32>) -> vector<f32> {
132// CHECK:           %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
133// CHECK:           %[[VAL_2:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<1xf32>
134// CHECK:           %[[VAL_3:.*]] = vector.insertelement %[[VAL_2]], %[[VAL_1]][] : vector<f32>
135// CHECK:           return %[[VAL_3]] : vector<f32>
136// CHECK:         }
137
138func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
139  %s = vector.shape_cast %arg0 : vector<1xf32> to vector<f32>
140  return %s : vector<f32>
141}
142
143module attributes {transform.with_named_sequence} {
144  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
145    %f = transform.structured.match ops{["func.func"]} in %module_op
146      : (!transform.any_op) -> !transform.any_op
147
148    transform.apply_patterns to %f {
149      transform.apply_patterns.vector.lower_shape_cast
150    } : !transform.any_op
151    transform.yield
152  }
153}
154