// RUN: mlir-opt -gpu-decompose-memrefs -allow-unregistered-dialect -split-input-file %s | FileCheck %s

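// Store through a memref with the default (identity) layout: the metadata is
// extracted outside the launch, the index is linearized as
// tx * stride0 + ty * stride1 + tz via affine.apply, and the store becomes a
// 0-d memref.store through a reinterpret_cast of the base buffer.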
//       CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
//       CHECK: @decompose_store
//  CHECK-SAME: (%[[VAL:.*]]: f32, %[[MEM:.*]]: memref<?x?x?xf32>)
//       CHECK:  %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
//       CHECK:  gpu.launch
//  CHECK-SAME:  threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
//       CHECK:  %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
//       CHECK:  %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
//       CHECK:  memref.store %[[VAL]], %[[PTR]][] : memref<f32, strided<[], offset: ?>>
func.func @decompose_store(%arg0 : f32, %arg1 : memref<?x?x?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %block_dim0 = memref.dim %arg1, %c0 : memref<?x?x?xf32>
  %block_dim1 = memref.dim %arg1, %c1 : memref<?x?x?xf32>
  %block_dim2 = memref.dim %arg1, %c2 : memref<?x?x?xf32>
  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
             threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
    memref.store %arg0, %arg1[%tx, %ty, %tz] : memref<?x?x?xf32>
    gpu.terminator
  }
  return
}

// -----

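// Same as above, but the source memref has fully dynamic strides and offset,
// so the extracted offset and all three strides feed into the linearization.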
//       CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * s2 + s3 * s4 + s5 * s6)>
//       CHECK: @decompose_store_strided
//  CHECK-SAME: (%[[VAL:.*]]: f32, %[[MEM:.*]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>)
//       CHECK:  %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
//       CHECK:  gpu.launch
//  CHECK-SAME:  threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
//       CHECK:  %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[OFFSET]], %[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]], %[[STRIDES]]#2]
//       CHECK:  %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
//       CHECK:  memref.store %[[VAL]], %[[PTR]][] : memref<f32, strided<[], offset: ?>>
func.func @decompose_store_strided(%arg0 : f32, %arg1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %block_dim0 = memref.dim %arg1, %c0 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
  %block_dim1 = memref.dim %arg1, %c1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
  %block_dim2 = memref.dim %arg1, %c2 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
             threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
    memref.store %arg0, %arg1[%tx, %ty, %tz] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
    gpu.terminator
  }
  return
}

// -----

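// Loads follow the same pattern as stores: a 0-d memref.load through a
// reinterpret_cast at the linearized offset.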
//       CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
//       CHECK: @decompose_load
//  CHECK-SAME: (%[[MEM:.*]]: memref<?x?x?xf32>)
//       CHECK:  %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
//       CHECK:  gpu.launch
//  CHECK-SAME:  threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
//       CHECK:  %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
//       CHECK:  %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
//       CHECK:  %[[RES:.*]] = memref.load %[[PTR]][] : memref<f32, strided<[], offset: ?>>
//       CHECK:  "test.test"(%[[RES]]) : (f32) -> ()
func.func @decompose_load(%arg0 : memref<?x?x?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %block_dim0 = memref.dim %arg0, %c0 : memref<?x?x?xf32>
  %block_dim1 = memref.dim %arg0, %c1 : memref<?x?x?xf32>
  %block_dim2 = memref.dim %arg0, %c2 : memref<?x?x?xf32>
  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
             threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
    %res = memref.load %arg0[%tx, %ty, %tz] : memref<?x?x?xf32>
    "test.test"(%res) : (f32) -> ()
    gpu.terminator
  }
  return
}

// -----

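// A subview with unit strides is folded into a single reinterpret_cast whose
// offset is the linearized [%tx, %ty, %tz] position and whose strides reuse
// the source strides (the innermost one is statically 1).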
//       CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
//       CHECK: @decompose_subview
//  CHECK-SAME: (%[[MEM:.*]]: memref<?x?x?xf32>)
//       CHECK:  %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
//       CHECK:  gpu.launch
//  CHECK-SAME:  threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
//       CHECK:  %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
//       CHECK:  %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [%{{.*}}, %{{.*}}, %{{.*}}], strides: [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
//       CHECK:  "test.test"(%[[PTR]]) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
func.func @decompose_subview(%arg0 : memref<?x?x?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %block_dim0 = memref.dim %arg0, %c0 : memref<?x?x?xf32>
  %block_dim1 = memref.dim %arg0, %c1 : memref<?x?x?xf32>
  %block_dim2 = memref.dim %arg0, %c2 : memref<?x?x?xf32>
  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
             threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
    %res = memref.subview %arg0[%tx, %ty, %tz] [%c2, %c2, %c2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
    "test.test"(%res) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
    gpu.terminator
  }
  return
}

// -----

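// A subview with static strides [2, 3, 4] scales the source strides: the
// result strides are stride0 * 2, stride1 * 3 and the constant 4 for the
// innermost dimension.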
//       CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 * 2)>
//       CHECK: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 * 3)>
//       CHECK: #[[MAP2:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
//       CHECK: @decompose_subview_strided
//  CHECK-SAME: (%[[MEM:.*]]: memref<?x?x?xf32>)
//       CHECK:  %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
//       CHECK:  gpu.launch
//  CHECK-SAME:  threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
//       CHECK:  %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[STRIDES]]#0]
//       CHECK:  %[[IDX1:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#1]
//       CHECK:  %[[IDX2:.*]] = affine.apply #[[MAP2]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
//       CHECK:  %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX2]]], sizes: [%{{.*}}, %{{.*}}, %{{.*}}], strides: [%[[IDX]], %[[IDX1]], 4]
//       CHECK:  "test.test"(%[[PTR]]) : (memref<?x?x?xf32, strided<[?, ?, 4], offset: ?>>) -> ()
func.func @decompose_subview_strided(%arg0 : memref<?x?x?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %block_dim0 = memref.dim %arg0, %c0 : memref<?x?x?xf32>
  %block_dim1 = memref.dim %arg0, %c1 : memref<?x?x?xf32>
  %block_dim2 = memref.dim %arg0, %c2 : memref<?x?x?xf32>
  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
             threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
    %res = memref.subview %arg0[%tx, %ty, %tz] [%c2, %c2, %c2] [2, 3, 4] : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, 4], offset: ?>>
    "test.test"(%res) : (memref<?x?x?xf32, strided<[?, ?, 4], offset: ?>>) -> ()
    gpu.terminator
  }
  return
}