// Source: llvm-project/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
// (revision 9816edc9f3ce198d41e364dd3467caa839a0c220)
// RUN: mlir-opt %s -test-scalar-vector-transfer-lowering -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-scalar-vector-transfer-lowering=allow-multiple-uses -split-input-file | FileCheck %s --check-prefix=MULTIUSE

// A 0-d transfer_read + extractelement on a memref lowers to a plain
// memref.load of the scalar.
// CHECK-LABEL: func @transfer_read_0d(
//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
//       CHECK:   %[[r:.*]] = memref.load %[[m]][%[[idx]], %[[idx]], %[[idx]]]
//       CHECK:   return %[[r]]
func.func @transfer_read_0d(%m: memref<?x?x?xf32>, %idx: index) -> f32 {
  %cst = arith.constant 0.0 : f32
  %0 = vector.transfer_read %m[%idx, %idx, %idx], %cst : memref<?x?x?xf32>, vector<f32>
  %1 = vector.extractelement %0[] : vector<f32>
  return %1 : f32
}

// -----

// An in-bounds 1-d transfer_read followed by a dynamic extractelement lowers
// to a memref.load whose innermost index is offset by the extraction position
// (s0 + s1). The dead %c0 constant from the original test was removed; it was
// never referenced.
//       CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK-LABEL: func @transfer_read_1d(
//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[idx2:.*]]: index
//       CHECK:   %[[added:.*]] = affine.apply #[[$map]]()[%[[idx]], %[[idx2]]]
//       CHECK:   %[[r:.*]] = memref.load %[[m]][%[[idx]], %[[idx]], %[[added]]]
//       CHECK:   return %[[r]]
func.func @transfer_read_1d(%m: memref<?x?x?xf32>, %idx: index, %idx2: index) -> f32 {
  %cst = arith.constant 0.0 : f32
  %0 = vector.transfer_read %m[%idx, %idx, %idx], %cst {in_bounds = [true]} : memref<?x?x?xf32>, vector<5xf32>
  %1 = vector.extractelement %0[%idx2 : index] : vector<5xf32>
  return %1 : f32
}

// -----

// The tensor counterpart of the 0-d read: lowers to tensor.extract instead of
// memref.load.
// CHECK-LABEL: func @tensor_transfer_read_0d(
//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index
//       CHECK:   %[[r:.*]] = tensor.extract %[[t]][%[[idx]], %[[idx]], %[[idx]]]
//       CHECK:   return %[[r]]
func.func @tensor_transfer_read_0d(%t: tensor<?x?x?xf32>, %idx: index) -> f32 {
  %cst = arith.constant 0.0 : f32
  %0 = vector.transfer_read %t[%idx, %idx, %idx], %cst : tensor<?x?x?xf32>, vector<f32>
  %1 = vector.extractelement %0[] : vector<f32>
  return %1 : f32
}

// -----

// A 0-d transfer_write lowers to memref.store of the scalar extracted from
// the written vector (the broadcast/extract pair is expected to fold away
// later).
// CHECK-LABEL: func @transfer_write_0d(
//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
//       CHECK:   %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector<f32>
//       CHECK:   %[[extract:.*]] = vector.extractelement %[[bc]][] : vector<f32>
//       CHECK:   memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
func.func @transfer_write_0d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
  %0 = vector.broadcast %f : f32 to vector<f32>
  vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<f32>, memref<?x?x?xf32>
  return
}

// -----

// A single-element (vector<1xf32>) transfer_write of a broadcast scalar
// lowers directly to memref.store of that scalar.
// CHECK-LABEL: func @transfer_write_1d(
//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
//       CHECK:   memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
func.func @transfer_write_1d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
  %0 = vector.broadcast %f : f32 to vector<1xf32>
  vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<1xf32>, memref<?x?x?xf32>
  return
}

// -----

// The tensor counterpart of the 0-d write: lowers to tensor.insert, which
// yields a new tensor value that is returned.
// CHECK-LABEL: func @tensor_transfer_write_0d(
//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
//       CHECK:   %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector<f32>
//       CHECK:   %[[extract:.*]] = vector.extractelement %[[bc]][] : vector<f32>
//       CHECK:   %[[r:.*]] = tensor.insert %[[extract]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
//       CHECK:   return %[[r]]
func.func @tensor_transfer_write_0d(%t: tensor<?x?x?xf32>, %idx: index, %f: f32) -> tensor<?x?x?xf32> {
  %0 = vector.broadcast %f : f32 to vector<f32>
  %1 = vector.transfer_write %0, %t[%idx, %idx, %idx] : vector<f32>, tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}

// -----

// A 2-d transfer_read followed by a scalar vector.extract at constant
// position [8, 1] lowers to a memref.load whose last two indices are offset
// by those constants. The dead %c0 constant from the original test was
// removed; it was never referenced.
//       CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 8)>
//       CHECK: #[[$map1:.*]] = affine_map<()[s0] -> (s0 + 1)>
// CHECK-LABEL: func @transfer_read_2d_extract(
//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?x?xf32>, %[[idx:.*]]: index, %[[idx2:.*]]: index
//       CHECK:   %[[added:.*]] = affine.apply #[[$map]]()[%[[idx]]]
//       CHECK:   %[[added1:.*]] = affine.apply #[[$map1]]()[%[[idx]]]
//       CHECK:   %[[r:.*]] = memref.load %[[m]][%[[idx]], %[[idx]], %[[added]], %[[added1]]]
//       CHECK:   return %[[r]]
func.func @transfer_read_2d_extract(%m: memref<?x?x?x?xf32>, %idx: index, %idx2: index) -> f32 {
  %cst = arith.constant 0.0 : f32
  %0 = vector.transfer_read %m[%idx, %idx, %idx, %idx], %cst {in_bounds = [true, true]} : memref<?x?x?x?xf32>, vector<10x5xf32>
  %1 = vector.extract %0[8, 1] : f32 from vector<10x5xf32>
  return %1 : f32
}

// -----

// Writing a 1x1 dense splat constant lowers to a memref.store of the scalar
// extracted from the constant vector.
// CHECK-LABEL: func @transfer_write_arith_constant(
//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
//       CHECK:   %[[cst:.*]] = arith.constant dense<5.000000e+00> : vector<1x1xf32>
//       CHECK:   %[[extract:.*]] = vector.extract %[[cst]][0, 0] : f32 from vector<1x1xf32>
//       CHECK:   memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
func.func @transfer_write_arith_constant(%m: memref<?x?x?xf32>, %idx: index) {
  %cst = arith.constant dense<5.000000e+00> : vector<1x1xf32>
  vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>
  return
}

// -----

// A transfer_read with multiple extract users is left untouched by default
// (CHECK prefix), but is rewritten into one memref.load per use when the
// pass is run with allow-multiple-uses (MULTIUSE prefix).
// CHECK-LABEL: func @transfer_read_multi_use(
//  CHECK-SAME:   %[[m:.*]]: memref<?xf32>, %[[idx:.*]]: index
//   CHECK-NOT:   memref.load
//       CHECK:   %[[r:.*]] = vector.transfer_read %[[m]][%[[idx]]]
//       CHECK:   %[[e0:.*]] = vector.extract %[[r]][0]
//       CHECK:   %[[e1:.*]] = vector.extract %[[r]][1]
//       CHECK:   return %[[e0]], %[[e1]]

// MULTIUSE-LABEL: func @transfer_read_multi_use(
//  MULTIUSE-SAME:   %[[m:.*]]: memref<?xf32>, %[[idx0:.*]]: index
//   MULTIUSE-NOT:   vector.transfer_read
//       MULTIUSE:   %[[r0:.*]] = memref.load %[[m]][%[[idx0]]
//       MULTIUSE:   %[[idx1:.*]] = affine.apply
//       MULTIUSE:   %[[r1:.*]] = memref.load %[[m]][%[[idx1]]
//       MULTIUSE:   return %[[r0]], %[[r1]]

func.func @transfer_read_multi_use(%m: memref<?xf32>, %idx: index) -> (f32, f32) {
  %cst = arith.constant 0.0 : f32
  %0 = vector.transfer_read %m[%idx], %cst {in_bounds = [true]} : memref<?xf32>, vector<16xf32>
  %1 = vector.extract %0[0] : f32 from vector<16xf32>
  %2 = vector.extract %0[1] : f32 from vector<16xf32>
  return %1, %2 : f32, f32
}

// -----

// Check that patterns don't trigger for a sub-vector (not scalar) extraction:
// the transfer_read must remain as-is.
// CHECK-LABEL: func @subvector_extract(
//  CHECK-SAME:   %[[m:.*]]: memref<?x?xf32>, %[[idx:.*]]: index
//   CHECK-NOT:   memref.load
//       CHECK:   %[[r:.*]] = vector.transfer_read %[[m]][%[[idx]], %[[idx]]]
//       CHECK:   %[[e0:.*]] = vector.extract %[[r]][0]
//       CHECK:   return %[[e0]]

func.func @subvector_extract(%m: memref<?x?xf32>, %idx: index) -> vector<16xf32> {
  %cst = arith.constant 0.0 : f32
  %0 = vector.transfer_read %m[%idx, %idx], %cst {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x16xf32>
  %1 = vector.extract %0[0] : vector<16xf32> from vector<8x16xf32>
  return %1 : vector<16xf32>
}