// RUN: mlir-opt --split-input-file --transform-interpreter %s | FileCheck %s
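
// Tile a matmul for GPU blocks, pad the tiled op and place the padding
// buffers in shared memory (memory space 3). The first test uses 1024x1024
// shapes that divide the tile sizes evenly; the second uses 1023x1023 shapes
// that do not.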

// CHECK-LABEL: func @matmul_divisible
//       CHECK:   scf.forall
//   CHECK-NOT:     memref.copy
//       CHECK:     linalg.fill
//       CHECK:     scf.for
//       CHECK:       memref.alloc() : memref<128x16xf32, 3>
//       CHECK:       scf.forall
//       CHECK:         vector.create_mask
//       CHECK:         vector.transfer_read
//       CHECK:         vector.transfer_write
//       CHECK:       memref.alloc() : memref<16x128xf32, 3>
//       CHECK:       scf.forall
//       CHECK:         vector.create_mask
//       CHECK:         vector.transfer_read
//       CHECK:         vector.transfer_write
//       CHECK:       memref.alloc() : memref<128x128xf32, 3>
//       CHECK:       scf.forall
//       CHECK:         vector.create_mask
//       CHECK:         vector.transfer_read
//       CHECK:         vector.transfer_write
//       CHECK:       linalg.matmul
//       CHECK:       scf.forall
//       CHECK:         vector.transfer_read
//       CHECK:         vector.transfer_write
func.func @matmul_divisible(%A: tensor<1024x1024xf32>,
                            %B: tensor<1024x1024xf32>,
                            %C: tensor<1024x1024xf32>)
    -> tensor<1024x1024xf32>
{
  %cst = arith.constant 0.000000e+00 : f32
  %0 = linalg.fill ins(%cst : f32)
                   outs(%C : tensor<1024x1024xf32>)
      -> tensor<1024x1024xf32>
  %1 = linalg.matmul ins(%A, %B : tensor<1024x1024xf32>, tensor<1024x1024xf32>)
                     outs(%0 : tensor<1024x1024xf32>)
      -> tensor<1024x1024xf32>
  return %1 : tensor<1024x1024xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) {
    // Fuse linalg.fill into linalg.matmul and tile.
    %matmul_op = transform.structured.match ops{["linalg.matmul"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
    %fill_op = transform.structured.match ops{["linalg.fill"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
    %tiled_matmul_op, %forall_op = transform.structured.tile_using_forall %matmul_op num_threads [] tile_sizes [128, 128](mapping = [#gpu.block<y>, #gpu.block<x>])
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %fused_op, %new_containing_op = transform.structured.fuse_into_containing_op %fill_op into %forall_op
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Tile linalg.matmul a second time.
    %tiled_linalg_op, %loops = transform.structured.tile_using_for %tiled_matmul_op tile_sizes [0, 0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Pad linalg.matmul.
    %padded, %pad, %copy_back = transform.structured.pad %tiled_linalg_op
        {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
         padding_dimensions=[0, 1, 2], nofold_flags=[1, 1, 1],
         copy_back_op = "linalg.copy"}
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)

    // Map and tile tensor.pad.
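    // The pad copies are distributed across 32 threads, targeting 128-bit
    // (4 x f32) aligned accesses.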
    %pad_forall_op, %tiled_pad_op = transform.structured.gpu.map_copy_to_threads
        %pad total_num_threads = 32 desired_bit_alignment = 128
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.foreach %pad_forall_op : !transform.any_op {
    ^bb2(%arg2 : !transform.any_op):
      %if_op = transform.structured.match ops{["scf.if"]} in %arg2
          : (!transform.any_op) -> !transform.any_op
      // TODO: The scf.if can be avoided with 0x... tensors.
      transform.scf.take_assumed_branch %if_op take_else_branch
          : (!transform.any_op) -> ()
    }

    // Map and tile copy back.
    %copy_forall_op, %tiled_copy_op = transform.structured.gpu.map_copy_to_threads
        %copy_back total_num_threads = 32 desired_bit_alignment = 128
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Apply masked vectorization to padding ops.
    transform.structured.vectorize %tiled_pad_op vector_sizes [128, 4]
        : !transform.any_op

    // Assign shared memory buffer to padding.
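    // Memory space 3 is the GPU shared (workgroup) address space.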
    %buffer, %new_ops = transform.structured.bufferize_to_allocation
        %pad_forall_op {memory_space = 3, bufferize_destination_only, emit_dealloc}
        : !transform.any_op

    // Bufferize.
    %func_op_1 = transform.structured.match ops{["func.func"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
    transform.bufferization.eliminate_empty_tensors %func_op_1 : !transform.any_op
    transform.apply_dce to %func_op_1 : !transform.any_op
    transform.apply_cse to %func_op_1 : !transform.any_op
    %bufferized = transform.bufferization.one_shot_bufferize
        layout{IdentityLayoutMap} %arg1 {bufferize_function_boundaries=true}
        : (!transform.any_op) -> !transform.any_op

    // Apply vectorization to copy back from shared memory.
    // TODO: Find a way to retain the handle to linalg.copy throughout
    // bufferization.
    %func_op_2 = transform.structured.match ops{["func.func"]} in %bufferized
        : (!transform.any_op) -> !transform.any_op
    %bufferized_copy_back = transform.structured.match ops{["linalg.copy"]} in %func_op_2
        : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize
        %bufferized_copy_back vector_sizes [128, 4] : !transform.any_op

    // Canonicalization, cleanup and vector lowering. This step also removes
    // buffer self-copies.
    transform.apply_patterns to %func_op_2 {
      transform.apply_patterns.canonicalization
      transform.apply_patterns.vector.lower_masked_transfers
    } {apply_cse} : !transform.any_op
    transform.yield
  }
}

// -----

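// Like the test above, but with 1023x1023 shapes that do not divide the
// 128x128x16 tile sizes evenly.
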
// CHECK-LABEL: func @matmul_not_divisible
//       CHECK:   scf.forall
//   CHECK-NOT:     memref.copy
//       CHECK:     linalg.fill
//       CHECK:     scf.for
//       CHECK:       memref.alloc() : memref<128x16xf32, 3>
//       CHECK:       scf.forall
//       CHECK:         vector.create_mask
//       CHECK:         vector.transfer_read
//       CHECK:         vector.transfer_write
//       CHECK:       memref.alloc() : memref<16x128xf32, 3>
//       CHECK:       scf.forall
//       CHECK:         vector.create_mask
//       CHECK:         vector.transfer_read
//       CHECK:         vector.transfer_write
//       CHECK:       memref.alloc() : memref<128x128xf32, 3>
//       CHECK:       scf.forall
//       CHECK:         vector.create_mask
//       CHECK:         vector.transfer_read
//       CHECK:         vector.transfer_write
//       CHECK:       linalg.matmul
//       CHECK:       vector.transfer_read
//       CHECK:       vector.transfer_write
func.func @matmul_not_divisible(%A: tensor<1023x1023xf32>,
                                %B: tensor<1023x1023xf32>,
                                %C: tensor<1023x1023xf32>)
    -> tensor<1023x1023xf32>
{
  %cst = arith.constant 0.000000e+00 : f32
  %0 = linalg.fill ins(%cst : f32)
                   outs(%C : tensor<1023x1023xf32>)
      -> tensor<1023x1023xf32>
  %1 = linalg.matmul ins(%A, %B : tensor<1023x1023xf32>, tensor<1023x1023xf32>)
                     outs(%0 : tensor<1023x1023xf32>)
      -> tensor<1023x1023xf32>
  return %1 : tensor<1023x1023xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) {
    // Fuse linalg.fill into linalg.matmul and tile.
    %matmul_op = transform.structured.match ops{["linalg.matmul"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
    %fill_op = transform.structured.match ops{["linalg.fill"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
    %tiled_matmul_op, %forall_op = transform.structured.tile_using_forall %matmul_op num_threads [] tile_sizes [128, 128](mapping = [#gpu.block<y>, #gpu.block<x>])
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %fused_op, %new_containing_op = transform.structured.fuse_into_containing_op %fill_op into %forall_op
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Tile linalg.matmul a second time.
    %tiled_linalg_op, %loops = transform.structured.tile_using_for %tiled_matmul_op tile_sizes [0, 0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Pad linalg.matmul.
    %padded, %pad, %copy_back = transform.structured.pad %tiled_linalg_op
        {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
         padding_dimensions=[0, 1, 2], nofold_flags=[1, 1, 1],
         copy_back_op = "linalg.copy"}
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)

    // Map and tile tensor.pad.
    %pad_forall_op, %tiled_pad_op = transform.structured.gpu.map_copy_to_threads
        %pad total_num_threads = 32 desired_bit_alignment = 128
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.foreach %pad_forall_op : !transform.any_op {
    ^bb2(%arg2 : !transform.any_op):
      %if_op = transform.structured.match ops{["scf.if"]} in %arg2
          : (!transform.any_op) -> !transform.any_op
      // TODO: The scf.if can be avoided with 0x... tensors.
      transform.scf.take_assumed_branch %if_op take_else_branch
          : (!transform.any_op) -> ()
    }

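    // Note that, unlike in the divisible case above, the copy back is not
    // mapped to threads here; it is vectorized directly after bufferization.
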
    // Apply masked vectorization to padding ops.
    transform.structured.vectorize %tiled_pad_op vector_sizes [128, 4]
        : !transform.any_op

    // Assign shared memory buffer to padding.
    %buffer, %new_ops = transform.structured.bufferize_to_allocation
        %pad_forall_op {memory_space = 3, bufferize_destination_only, emit_dealloc}
        : !transform.any_op

    // Bufferize.
    %func_op_1 = transform.structured.match ops{["func.func"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
    transform.bufferization.eliminate_empty_tensors %func_op_1 : !transform.any_op
    transform.apply_dce to %func_op_1 : !transform.any_op
    transform.apply_cse to %func_op_1 : !transform.any_op
    %bufferized = transform.bufferization.one_shot_bufferize
        layout{IdentityLayoutMap} %arg1 {bufferize_function_boundaries=true}
        : (!transform.any_op) -> !transform.any_op

    // Apply vectorization to copy back from shared memory.
    // TODO: Find a way to retain the handle to linalg.copy throughout
    // bufferization.
    %func_op_2 = transform.structured.match ops{["func.func"]} in %bufferized
        : (!transform.any_op) -> !transform.any_op
    %bufferized_copy_back = transform.structured.match ops{["linalg.copy"]} in %func_op_2
        : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize
        %bufferized_copy_back vector_sizes [128, 4] : !transform.any_op

    // Canonicalization, cleanup and vector lowering. This step also removes
    // buffer self-copies.
    transform.apply_patterns to %func_op_2 {
      transform.apply_patterns.canonicalization
      transform.apply_patterns.vector.lower_masked_transfers
    } {apply_cse} : !transform.any_op
    transform.yield
  }
}