// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --verify-diagnostics --split-input-file %s | FileCheck %s
// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --verify-diagnostics --split-input-file %s | FileCheck %s --check-prefix=CHECK32

// Expect no conversions.
func.func @memref_i8() -> i8 {
    %c3 = arith.constant 3 : index
    %m = memref.alloc() : memref<4xi8, 1>
    %v = memref.load %m[%c3] : memref<4xi8, 1>
    memref.dealloc %m : memref<4xi8, 1>
    return %v : i8
}
// CHECK-LABEL: func @memref_i8()
//       CHECK:   %[[M:.+]] = memref.alloc() : memref<4xi8, 1>
//  CHECK-NEXT:   %[[V:.+]] = memref.load %[[M]][%{{.+}}] : memref<4xi8, 1>
//  CHECK-NEXT:   memref.dealloc %[[M]]
//  CHECK-NEXT:   return %[[V]]

// CHECK32-LABEL: func @memref_i8()
//       CHECK32:   %[[M:.+]] = memref.alloc() : memref<1xi32, 1>
//       CHECK32:   %[[C0:.+]] = arith.constant 0 : index
//       CHECK32:   %[[V:.+]] = memref.load %[[M]][%[[C0]]] : memref<1xi32, 1>
//       CHECK32:   %[[C24:.+]] = arith.constant 24 : index
//       CHECK32:   %[[CAST:.+]] = arith.index_cast %[[C24]] : index to i32
//       CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[V]], %[[CAST]]
//       CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i8
//  CHECK32-NEXT:   memref.dealloc %[[M]]
//  CHECK32-NEXT:   return %[[TRUNC]]

// -----

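// Loading an i4 is emulated by loading the wider backing element, shifting it
// right by the bit offset, and truncating to i4.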
func.func @memref_load_i4(%arg0: index) -> i4 {
    %0 = memref.alloc() : memref<5xi4>
    %1 = memref.load %0[%arg0] : memref<5xi4>
    return %1 : i4
}
//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
//      CHECK: func @memref_load_i4(
// CHECK-SAME:     %[[ARG0:.+]]: index
//      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
//      CHECK:   %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
//      CHECK:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
//      CHECK:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
//      CHECK:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
//      CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
//      CHECK:   return %[[TRUNC]]

//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
//      CHECK32: func @memref_load_i4(
// CHECK32-SAME:     %[[ARG0:.+]]: index
//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
//      CHECK32:   %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
//      CHECK32:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
//      CHECK32:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
//      CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
//      CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
//      CHECK32:   return %[[TRUNC]]

// -----

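// Multi-dimensional memrefs are linearized to a 1-D backing memref;
// assume_alignment is carried over to the flattened allocation.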
func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
    %0 = memref.alloc() : memref<3x125xi4>
    memref.assume_alignment %0, 64 : memref<3x125xi4>
    %1 = memref.load %0[%arg0,%arg1] : memref<3x125xi4>
    return %1 : i4
}
//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)
//      CHECK: func @memref_load_i4_rank2(
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
//      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
//      CHECK:   memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
//      CHECK:   %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
//      CHECK:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
//      CHECK:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
//      CHECK:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
//      CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
//      CHECK:   return %[[TRUNC]]

//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 8) * 32)
//      CHECK32: func @memref_load_i4_rank2(
// CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
//      CHECK32:   memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
//      CHECK32:   %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
//      CHECK32:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
//      CHECK32:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
//      CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
//      CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
//      CHECK32:   return %[[TRUNC]]

// -----

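// With dynamic shapes, the linearized allocation size, load index, and bit
// offset are all computed via affine.apply.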
func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %arg3 : index) -> i4 {
  %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
  %1 = memref.load %0[%arg2, %arg3] : memref<?x?xi4>
  return %1 : i4
}
//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
//  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
//      CHECK: func @memref_load_i4_dynamic(
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
//      CHECK:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
//      CHECK:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
//      CHECK:   %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
//      CHECK:   %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
//      CHECK:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
//      CHECK:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
//      CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
//      CHECK:   return %[[TRUNC]]

//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
//  CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
//      CHECK32: func @memref_load_i4_dynamic(
// CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
// CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
// CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
//      CHECK32:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
//      CHECK32:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
//      CHECK32:   %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
//      CHECK32:   %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
//      CHECK32:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
//      CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
//      CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
//      CHECK32:   return %[[TRUNC]]

// -----

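// A rank-0 i4 memref occupies a single backing element, so the load only
// needs a truncation.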
func.func @rank_zero_memref() -> i4 {
  %0 = memref.alloc() : memref<i4>
  %1 = memref.load %0[] : memref<i4>
  return %1 : i4
}
// CHECK-LABEL: func @rank_zero_memref()
//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<i8>
//       CHECK:   %[[LOAD:.+]] = memref.load %[[ALLOC]][] : memref<i8>
//       CHECK:   %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i8 to i4
//       CHECK:   return %[[TRUNC]]

// CHECK32-LABEL: func @rank_zero_memref()
//       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<i32>
//       CHECK32:   %[[LOAD:.+]] = memref.load %[[ALLOC]][] : memref<i32>
//       CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i32 to i4
//       CHECK32:   return %[[TRUNC]]

// -----

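// A contiguous subview is rewritten in terms of the backing element type,
// with its offset and size rescaled.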
func.func @memref_strided_i4(%idx : index) -> i4 {
  %arr = memref.alloc() : memref<128xi4>
  %subview = memref.subview %arr[32] [32] [1] : memref<128xi4> to memref<32xi4, strided<[1], offset:32>>
  %1 = memref.load %subview[%idx] : memref<32xi4, strided<[1], offset:32>>
  return %1 : i4
}

// CHECK-LABEL: func @memref_strided_i4
//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<64xi8>
//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16] [16] [1] : memref<64xi8> to memref<16xi8, strided<[1], offset: 16>>
//       CHECK:   %[[LOAD:.+]] = memref.load %[[SUBVIEW]]

// CHECK32-LABEL: func @memref_strided_i4
//       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
//       CHECK32:   %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
//       CHECK32:   %[[LOAD:.+]] = memref.load %[[SUBVIEW]]

// -----

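// A contiguous subview with a dynamic offset is also linearized; the new
// offset is computed with affine.apply.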
func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
  %c0 = arith.constant 0 : index
  %arr = memref.alloc() : memref<512x64x8x16xi4>
  %subview = memref.subview %arr[%idx, 0, 0, 0] [16, 64, 8, 16] [1, 1, 1, 1] : memref<512x64x8x16xi4>
                                                                            to memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>
  %ld = memref.load %subview[%c0, %c0, %c0, %c0] : memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>
  return %ld : i4
}

// CHECK-LABEL:   func.func @memref_subview_dynamic_offset_i4(
// CHECK:           %[[ALLOC:.*]] = memref.alloc() : memref<2097152xi8>
// CHECK:           %[[IDX:.*]] = affine.apply
// CHECK:           %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [65536] [1] : memref<2097152xi8> to memref<65536xi8, strided<[1], offset: ?>>
// CHECK:           memref.load %[[SUBVIEW]]

// CHECK32-LABEL:   func.func @memref_subview_dynamic_offset_i4(
// CHECK32:           %[[ALLOC:.*]] = memref.alloc() : memref<524288xi32>
// CHECK32:           %[[IDX:.*]] = affine.apply
// CHECK32:           %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [16384] [1] : memref<524288xi32> to memref<16384xi32, strided<[1], offset: ?>>
// CHECK32:           memref.load %[[SUBVIEW]]

// -----

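// A non-contiguous subview cannot be expressed on the wider backing type, so
// legalization is expected to fail.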
func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
  %c0 = arith.constant 0 : index
  %arr = memref.alloc() : memref<40x40xi4>
  // expected-error @+1 {{failed to legalize operation 'memref.subview' that was explicitly marked illegal}}
  %subview = memref.subview %arr[%idx, 0] [4, 8] [1, 1] : memref<40x40xi4> to memref<4x8xi4, strided<[40, 1], offset:?>>
  %ld = memref.load %subview[%c0, %c0] : memref<4x8xi4, strided<[40, 1], offset:?>>
  return %ld : i4
}

// -----

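// reinterpret_cast to a rank-0 memref is rewritten on the backing element
// type; the load then only needs a truncation.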
func.func @reinterpret_cast_memref_load_0D() -> i4 {
    %0 = memref.alloc() : memref<5xi4>
    %reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>
    %1 = memref.load %reinterpret_cast_0[] : memref<i4>
    return %1 : i4
}
// CHECK-LABEL: func @reinterpret_cast_memref_load_0D()
//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
//       CHECK:   %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [], strides: [] : memref<3xi8> to memref<i8>
//       CHECK:   %[[LOAD:.+]] = memref.load %[[RE_CAST]][] : memref<i8>
//       CHECK:   %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i8 to i4
//       CHECK:   return %[[TRUNC]]

// CHECK32-LABEL: func @reinterpret_cast_memref_load_0D()
//       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
//       CHECK32:   %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [], strides: [] : memref<1xi32> to memref<i32>
//       CHECK32:   %[[LOAD:.+]] = memref.load %[[RE_CAST]][] : memref<i32>
//       CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i32 to i4
//       CHECK32:   return %[[TRUNC]]

// -----

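// reinterpret_cast offsets and sizes are rescaled to the backing element
// type; the load then follows the usual shift-and-truncate pattern.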
func.func @reinterpret_cast_memref_load_1D(%arg0: index) -> i4 {
    %0 = memref.alloc() : memref<5x5xi4>
    %reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [8], sizes: [25], strides: [1] : memref<5x5xi4> to memref<25xi4, strided<[1], offset:8>>
    %1 = memref.load %reinterpret_cast_0[%arg0] : memref<25xi4, strided<[1], offset:8>>
    return %1 : i4
}
//   CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
//       CHECK: func @reinterpret_cast_memref_load_1D(
//  CHECK-SAME: %[[ARG0:.+]]: index
//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<13xi8>
//       CHECK:   %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [4], sizes: [13], strides: [1] : memref<13xi8> to memref<13xi8, strided<[1], offset: 4>>
//       CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
//       CHECK:   %[[LOAD:.+]] = memref.load %[[RE_CAST]][%[[INDEX]]] : memref<13xi8, strided<[1], offset: 4>>
//       CHECK:   %[[OFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
//       CHECK:   %[[CAST:.+]] = arith.index_cast %[[OFFSET]] : index to i8
//       CHECK:   %[[SHR:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] : i8
//       CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i8 to i4
//       CHECK:   return %[[TRUNC]]

//   CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
//   CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
//       CHECK32: func @reinterpret_cast_memref_load_1D(
//  CHECK32-SAME: %[[ARG0:.+]]: index
//       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<4xi32>
//       CHECK32:   %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [1], sizes: [4], strides: [1] : memref<4xi32> to memref<4xi32, strided<[1], offset: 1>>
//       CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
//       CHECK32:   %[[LOAD:.+]] = memref.load %[[RE_CAST]][%[[INDEX]]] : memref<4xi32, strided<[1], offset: 1>>
//       CHECK32:   %[[OFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
//       CHECK32:   %[[CAST:.+]] = arith.index_cast %[[OFFSET]] : index to i32
//       CHECK32:   %[[SHR:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] : i32
//       CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i32 to i4
//       CHECK32:   return %[[TRUNC]]

// -----

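// memref.alloca is emulated the same way as memref.alloc.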
func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
    %0 = memref.alloca() : memref<5xi4>
    %1 = memref.load %0[%arg0] : memref<5xi4>
    return %1 : i4
}
//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
//      CHECK: func @memref_alloca_load_i4(
// CHECK-SAME:     %[[ARG0:.+]]: index
//      CHECK:   %[[ALLOCA:.+]] = memref.alloca() : memref<3xi8>
//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
//      CHECK:   %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
//      CHECK:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
//      CHECK:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
//      CHECK:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
//      CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
//      CHECK:   return %[[TRUNC]]

//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
//      CHECK32: func @memref_alloca_load_i4(
// CHECK32-SAME:     %[[ARG0:.+]]: index
//      CHECK32:   %[[ALLOCA:.+]] = memref.alloca() : memref<1xi32>
//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
//      CHECK32:   %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
//      CHECK32:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
//      CHECK32:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
//      CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
//      CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
//      CHECK32:   return %[[TRUNC]]

// -----

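// Storing an i4 is emulated read-modify-write style: an atomic andi with a
// shifted mask clears the destination bits, then an atomic ori merges in the
// shifted value.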
func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
    %0 = memref.alloc() : memref<5xi4>
    memref.store %arg1, %0[%arg0] : memref<5xi4>
    return
}
//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
//      CHECK: func @memref_store_i4(
// CHECK-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
//  CHECK-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
//  CHECK-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i8
//  CHECK-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
//  CHECK-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
//  CHECK-DAG:   %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
//  CHECK-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i8
//  CHECK-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
//  CHECK-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
//  CHECK-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
//  CHECK-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
//      CHECK:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<3xi8>) -> i8
//      CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<3xi8>) -> i8
//      CHECK:   return

//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
//      CHECK32: func @memref_store_i4(
// CHECK32-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
//  CHECK32-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
//  CHECK32-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i32
//  CHECK32-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
//  CHECK32-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
//  CHECK32-DAG:   %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
//  CHECK32-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i32
//  CHECK32-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
//  CHECK32-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
//  CHECK32-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
//  CHECK32-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
//      CHECK32:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<1xi32>) -> i32
//      CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<1xi32>) -> i32
//      CHECK32:   return

// -----

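// Rank-2 store: the indices are linearized first, then the same
// mask-and-atomic-rmw sequence is emitted.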
func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
    %0 = memref.alloc() : memref<3x125xi4>
    memref.assume_alignment %0, 64 : memref<3x125xi4>
    memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
    return
}
//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)>
//      CHECK: func @memref_store_i4_rank2(
// CHECK-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
//  CHECK-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
//  CHECK-DAG:   memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
//  CHECK-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
//  CHECK-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
//  CHECK-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
//  CHECK-DAG:   %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
//  CHECK-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i8
//  CHECK-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
//  CHECK-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
//  CHECK-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
//  CHECK-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
//      CHECK:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
//      CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
//      CHECK:   return

//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 8) * 32)>
//      CHECK32: func @memref_store_i4_rank2(
// CHECK32-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
//  CHECK32-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
//  CHECK32-DAG:   memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
//  CHECK32-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i32
//  CHECK32-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
//  CHECK32-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
//  CHECK32-DAG:   %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
//  CHECK32-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i32
//  CHECK32-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
//  CHECK32-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
//  CHECK32-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
//  CHECK32-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
//      CHECK32:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
//      CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
//      CHECK32:   return

// -----

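// Dynamic-shape store: the linearized size, index, and bit offset are
// computed with affine.apply before the atomic RMW sequence.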
func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: i4) -> () {
  %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
  memref.store %arg4, %0[%arg2, %arg3] : memref<?x?xi4>
  return
}
//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
//  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
//      CHECK: func @memref_store_i4_dynamic(
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
// CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: i4
//  CHECK-DAG:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
//  CHECK-DAG:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
//  CHECK-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i8
//  CHECK-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
//  CHECK-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
//  CHECK-DAG:   %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
//  CHECK-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i8
//  CHECK-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
//  CHECK-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
//  CHECK-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
//  CHECK-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
//      CHECK:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
//      CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
//      CHECK:   return

//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
//  CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
//      CHECK32: func @memref_store_i4_dynamic(
// CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
// CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
// CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
// CHECK32-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: i4
//  CHECK32-DAG:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
//  CHECK32-DAG:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
//  CHECK32-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i32
//  CHECK32-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
//  CHECK32-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
//  CHECK32-DAG:   %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
//  CHECK32-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i32
//  CHECK32-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
//  CHECK32-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
//  CHECK32-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
//  CHECK32-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
//      CHECK32:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<?xi32>) -> i32
//      CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<?xi32>) -> i32
//      CHECK32:   return

// -----

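// A rank-0 store overwrites the entire backing element, so a single atomic
// assign suffices.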
func.func @rank_zero_memref_store(%arg0: i4) -> () {
  %0 = memref.alloc() : memref<i4>
  memref.store %arg0, %0[] : memref<i4>
  return
}
// CHECK-LABEL: func @rank_zero_memref
//  CHECK-SAME:     %[[ARG0:.+]]: i4
//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<i8>
//       CHECK:   %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i8
//       CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
//       CHECK:   return

// CHECK32-LABEL: func @rank_zero_memref
//  CHECK32-SAME:     %[[ARG0:.+]]: i4
//       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<i32>
//       CHECK32:   %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
//       CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
//       CHECK32:   return

// -----

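// collapse_shape folds away: the backing allocation is already 1-D, so the
// load addresses it directly.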
func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
  %arr = memref.alloc() : memref<32x8x128xi4>
  %collapse = memref.collapse_shape %arr[[0, 1], [2]] : memref<32x8x128xi4> into memref<256x128xi4>
  %1 = memref.load %collapse[%idx0, %idx1] : memref<256x128xi4>
  return %1 : i4
}

// CHECK-LABEL:   func.func @memref_collapse_shape_i4(
//       CHECK:     %[[ALLOC:.*]] = memref.alloc() : memref<16384xi8>
//   CHECK-NOT:     memref.collapse_shape
//       CHECK:     memref.load %[[ALLOC]][%{{.*}}] : memref<16384xi8>

// CHECK32-LABEL:   func.func @memref_collapse_shape_i4(
//       CHECK32:     %[[ALLOC:.*]] = memref.alloc() : memref<4096xi32>
//   CHECK32-NOT:     memref.collapse_shape
//       CHECK32:     memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>

// -----

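// expand_shape likewise folds away; the load indices are linearized against
// the flat backing memref.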
func.func @memref_expand_shape_i4(%idx0 : index, %idx1 : index, %idx2 : index) -> i4 {
  %arr = memref.alloc() : memref<256x128xi4>
  %expand = memref.expand_shape %arr[[0, 1], [2]] output_shape [32, 8, 128] : memref<256x128xi4> into memref<32x8x128xi4>
  %1 = memref.load %expand[%idx0, %idx1, %idx2] : memref<32x8x128xi4>
  return %1 : i4
}

// CHECK-LABEL:   func.func @memref_expand_shape_i4(
//       CHECK:     %[[ALLOC:.*]] = memref.alloc() : memref<16384xi8>
//   CHECK-NOT:     memref.expand_shape
//       CHECK:     memref.load %[[ALLOC]][%{{.*}}] : memref<16384xi8>

// CHECK32-LABEL:   func.func @memref_expand_shape_i4(
//       CHECK32:     %[[ALLOC:.*]] = memref.alloc() : memref<4096xi32>
//   CHECK32-NOT:     memref.expand_shape
//       CHECK32:     memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>

// -----

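// memory_space_cast is rewritten to operate on the linearized backing type.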
func.func @memref_memory_space_cast_i4(%arg0: memref<32x128xi4, 1>) -> memref<32x128xi4> {
  %cast = memref.memory_space_cast %arg0 : memref<32x128xi4, 1> to memref<32x128xi4>
  return %cast : memref<32x128xi4>
}

// CHECK-LABEL:   func.func @memref_memory_space_cast_i4(
//  CHECK-SAME:   %[[ARG0:.*]]: memref<2048xi8, 1>
//       CHECK:     %[[CAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<2048xi8, 1> to memref<2048xi8>
//       CHECK:     return %[[CAST]]

// CHECK32-LABEL:   func.func @memref_memory_space_cast_i4(
//  CHECK32-SAME:   %[[ARG0:.*]]: memref<512xi32, 1>
//       CHECK32:     %[[CAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<512xi32, 1> to memref<512xi32>
//       CHECK32:     return %[[CAST]]

// -----

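// memref.copy simply operates on the converted (linearized) memrefs.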
func.func @memref_copy_i4(%arg0: memref<32x128xi4, 1>, %arg1: memref<32x128xi4>) {
  memref.copy %arg0, %arg1 : memref<32x128xi4, 1> to memref<32x128xi4>
  return
}

// CHECK-LABEL:   func.func @memref_copy_i4(
//  CHECK-SAME:   %[[ARG0:.*]]: memref<2048xi8, 1>, %[[ARG1:.*]]: memref<2048xi8>
//       CHECK:     memref.copy %[[ARG0]], %[[ARG1]]
//       CHECK:     return

// CHECK32-LABEL:   func.func @memref_copy_i4(
//  CHECK32-SAME:   %[[ARG0:.*]]: memref<512xi32, 1>, %[[ARG1:.*]]: memref<512xi32>
//       CHECK32:     memref.copy %[[ARG0]], %[[ARG1]]
//       CHECK32:     return

// -----

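// An allocation with a non-contiguous layout cannot be emulated, so the alloc
// fails to legalize.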
func.func @alloc_non_contiguous() {
  // expected-error @+1 {{failed to legalize operation 'memref.alloc' that was explicitly marked illegal}}
  %arr = memref.alloc() : memref<8x8xi4, strided<[1, 8]>>
  return
}

// -----

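// Likewise, a function argument with a non-contiguous layout cannot be
// converted.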
// expected-error @+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
func.func @argument_non_contiguous(%arg0 : memref<8x8xi4, strided<[1, 8]>>) {
  return
}
557