xref: /llvm-project/mlir/test/Dialect/Linalg/transform-patterns.mlir (revision f59b0c76030aff268b78d475e219708d06b982b5)
1// RUN: mlir-opt %s -transform-interpreter -test-linalg-transform-patterns=test-patterns -split-input-file | FileCheck %s
2
// Payload for the 1-D tiling test: a single linalg.dot over two dynamically
// sized f32 memrefs (unit stride, dynamic offset) whose result goes to the
// rank-0 memref %v passed as the outs operand.
3func.func @dot(%x: memref<?xf32, strided<[1], offset: ?>>,
4          %y: memref<?xf32, strided<[1], offset: ?>>,
5          %v: memref<f32>) {
6  linalg.dot ins(%x, %y: memref<?xf32, strided<[1], offset: ?>>,
7                         memref<?xf32, strided<[1], offset: ?>>)
8            outs(%v: memref<f32>)
9  return
10}
11
// Transform script for @dot: the named sequence @__transform_main matches
// every linalg.dot under the read-only root handle %arg1 and tiles it with
// tile size 8000. The tiling op returns a handle to the tiled op (%1) and a
// handle to the single generated loop (%loop); neither is used afterwards.
12module attributes {transform.with_named_sequence} {
13  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
14      %0 = transform.structured.match ops{["linalg.dot"]} in %arg1 : (!transform.any_op) -> !transform.any_op
15      %1, %loop = transform.structured.tile_using_for %0 tile_sizes [8000] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
16      transform.yield
17  }
18}
19
20// CHECK-LABEL: func @dot
21// CHECK-DAG:     %[[c0:.*]] = arith.constant 0 : index
22// CHECK-DAG:     %[[c8000:.*]] = arith.constant 8000 : index
23// CHECK:         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] {
24// CHECK:           linalg.dot
25
26// -----
27
// Payload for the 2-D tiling test: a single linalg.matvec taking a
// dynamically shaped 2-D f32 memref %A (unit innermost stride) and a 1-D
// memref %x as inputs, with the 1-D memref %y as the outs operand.
28func.func @matvec(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
29             %x: memref<?xf32, strided<[1], offset: ?>>,
30             %y: memref<?xf32, strided<[1], offset: ?>>) {
31  linalg.matvec
32    ins(%A, %x: memref<?x?xf32, strided<[?, 1], offset: ?>>,
33                memref<?xf32, strided<[1], offset: ?>>)
34    outs(%y: memref<?xf32, strided<[1], offset: ?>>)
35  return
36}
37
// Transform script for @matvec: matches every linalg.matvec under %arg1 and
// tiles it with tile sizes [5, 6]. The tiling op returns the tiled op handle
// (%1) plus two loop handles (%loops:2), one per tiled dimension.
38module attributes {transform.with_named_sequence} {
39  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
40      %0 = transform.structured.match ops{["linalg.matvec"]} in %arg1 : (!transform.any_op) -> !transform.any_op
41      %1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [5, 6] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
42      transform.yield
43  }
44}
45
// Expected output for @matvec: an outer scf.for stepping by 5 and an inner
// scf.for stepping by 6, both starting at the captured zero constant, wrapped
// around a linalg.matvec that keeps the strided memref operand types.
// The lower bound is pinned to the zero constant so that the captured c0 is
// actually used, matching the loop expectations of the other tiling tests in
// this file.
46// CHECK-LABEL: func @matvec
47// CHECK-DAG:     %[[c0:.*]] = arith.constant 0 : index
48// CHECK-DAG:     %[[c5:.*]] = arith.constant 5 : index
49// CHECK-DAG:     %[[c6:.*]] = arith.constant 6 : index
50// CHECK:         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]]
51// CHECK:           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]]
52// CHECK:             linalg.matvec
53// CHECK:               ins({{.*}}: memref<?x?xf32, strided<[?, 1], offset: ?>>, memref<?xf32, strided<[1], offset: ?>>)
54// CHECK:              outs({{.*}}: memref<?xf32, strided<[1], offset: ?>>)
55
56// -----
57
// Payload for the multi-level tiling test: a single linalg.matmul whose two
// inputs (%A, %B) and outs operand (%C) are all dynamically shaped 2-D f32
// memrefs with a unit innermost stride and dynamic offset.
58func.func @matmul(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
59             %B: memref<?x?xf32, strided<[?, 1], offset: ?>>,
60             %C: memref<?x?xf32, strided<[?, 1], offset: ?>>) {
61  linalg.matmul ins(%A, %B: memref<?x?xf32, strided<[?, 1], offset: ?>>,
62                            memref<?x?xf32, strided<[?, 1], offset: ?>>)
63               outs(%C: memref<?x?xf32, strided<[?, 1], offset: ?>>)
64  return
65}
66
// Transform script for @matmul: matches every linalg.matmul under %arg1 and
// tiles it four times in a row. Each tiling step consumes the tiled-op handle
// produced by the previous step (%0 -> %1 -> %2 -> %3), using progressively
// smaller tile sizes: [2000, 3000, 4000], [200, 300, 400], [20, 30, 40] and
// [2, 3, 4]. Every step also returns three loop handles, which are unused.
67module attributes {transform.with_named_sequence} {
68  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
69      %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
70      %1, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [2000, 3000, 4000] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
71      %2, %loops_2:3 = transform.structured.tile_using_for %1 tile_sizes [200, 300, 400] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
72      %3, %loops_3:3 = transform.structured.tile_using_for %2 tile_sizes [20, 30, 40] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
73      %4, %loops_4:3 = transform.structured.tile_using_for %3 tile_sizes [2, 3, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
74      transform.yield
75  }
76}
77
78// CHECK-LABEL: func @matmul
79// CHECK-DAG:     %[[c0:.*]] = arith.constant 0 : index
80// CHECK-DAG:     %[[c2:.*]] = arith.constant 2 : index
81// CHECK-DAG:     %[[c3:.*]] = arith.constant 3 : index
82// CHECK-DAG:     %[[c4:.*]] = arith.constant 4 : index
83// CHECK-DAG:     %[[c20:.*]] = arith.constant 20 : index
84// CHECK-DAG:     %[[c30:.*]] = arith.constant 30 : index
85// CHECK-DAG:     %[[c40:.*]] = arith.constant 40 : index
86// CHECK-DAG:     %[[c200:.*]] = arith.constant 200 : index
87// CHECK-DAG:     %[[c300:.*]] = arith.constant 300 : index
88// CHECK-DAG:     %[[c400:.*]] = arith.constant 400 : index
89// CHECK-DAG:     %[[c2000:.*]] = arith.constant 2000 : index
90// CHECK-DAG:     %[[c3000:.*]] = arith.constant 3000 : index
91// CHECK-DAG:     %[[c4000:.*]] = arith.constant 4000 : index
92// CHECK:         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] {
93// CHECK:           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] {
94// CHECK:             scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] {
95// CHECK:               scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c200]] {
96// CHECK:                 scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c300]] {
97// CHECK:                   scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c400]] {
98// CHECK:                     scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] {
99// CHECK:                       scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
100// CHECK:                         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
101// CHECK:                           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2]] {
102// CHECK:                             scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3]] {
103// CHECK:                               scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] {
104// CHECK:                                 linalg.matmul
105// CHECK:                                   ins({{.*}}: memref<?x?xf32, strided<[?, 1], offset: ?>>, memref<?x?xf32, strided<[?, 1], offset: ?>>)
106// CHECK:                                  outs({{.*}}: memref<?x?xf32, strided<[?, 1], offset: ?>>)
107
108// -----
109
110// Indexing maps expected on the linalg.generic after its iterators are interchanged with [1, 2, 0].
111// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
112// CHECK-DAG: #[[$nm:.*]] = affine_map<(d0, d1, d2) -> (d1, d0)>
113// CHECK-DAG: #[[$km:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
114
// Attribute aliases describing a matmul written as a linalg.generic.
// #matmul_accesses lists one indexing map per operand over the iteration
// space (m, n, k): the first operand is accessed at (m, k), the second at
// (k, n) and the third at (m, n).
// #generic_matmul_trait bundles those maps with the iterator types (the first
// two iterators parallel, the last one a reduction) and the library call name
// "linalg_matmul".
115#matmul_accesses = [
116  affine_map<(m, n, k) -> (m, k)>,
117  affine_map<(m, n, k) -> (k, n)>,
118  affine_map<(m, n, k) -> (m, n)>
119]
120#generic_matmul_trait = {
121  indexing_maps = #matmul_accesses,
122  library_call = "linalg_matmul",
123  iterator_types = ["parallel", "parallel", "reduction"]
124}
// Payload for the interchange test: a linalg.generic configured by
// #generic_matmul_trait with inputs %A, %B and outs operand %C, all
// dynamically shaped 2-D f32 memrefs with a unit innermost stride.
125func.func @permute_generic(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
126           %B: memref<?x?xf32, strided<[?, 1], offset: ?>>,
127           %C: memref<?x?xf32, strided<[?, 1], offset: ?>>) {
128  linalg.generic #generic_matmul_trait
129    ins(%A, %B : memref<?x?xf32, strided<[?, 1], offset: ?>>,
130                 memref<?x?xf32, strided<[?, 1], offset: ?>>)
131   outs(%C : memref<?x?xf32, strided<[?, 1], offset: ?>>) {
    // Scalar body: multiply the two input elements and add the product to
    // the current output element, then yield the sum.
132    ^bb(%a: f32, %b: f32, %c: f32):
133      %d = arith.mulf %a, %b: f32
134      %e = arith.addf %c, %d: f32
135      linalg.yield %e: f32
136  }
137  return
138}
139
// Transform script for @permute_generic: matches every linalg.generic under
// %arg1 and applies an iterator interchange with the permutation [1, 2, 0].
// The handle returned by the interchange op is not used.
140module attributes {transform.with_named_sequence} {
141  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
142    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
143    transform.structured.interchange %0 iterator_interchange = [1, 2, 0] : (!transform.any_op) -> !transform.any_op
144    transform.yield
145  }
146}
147
148// CHECK-LABEL:  func @permute_generic
149// CHECK:        linalg.generic {
150// CHECK-SAME:   indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]],
151// CHECK-SAME:   iterator_types = ["parallel", "reduction", "parallel"],
152// CHECK-SAME:   library_call = "linalg_matmul"}
153// CHECK:          memref<?x?xf32, strided<[?, 1], offset: ?>>,
154// CHECK-SAME:     memref<?x?xf32, strided<[?, 1], offset: ?>>
155// CHECK-SAME:     memref<?x?xf32, strided<[?, 1], offset: ?>>
156
157// -----
158
// Payload for the tiling-with-interchange test: a single linalg.matvec with
// a dynamically shaped 2-D f32 memref %A and 1-D memref %x as inputs and the
// 1-D memref %y as the outs operand (same operand types as @matvec).
159func.func @matvec_perm(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
160             %x: memref<?xf32, strided<[1], offset: ?>>,
161             %y: memref<?xf32, strided<[1], offset: ?>>) {
162  linalg.matvec ins(%A, %x: memref<?x?xf32, strided<[?, 1], offset: ?>>,
163                            memref<?xf32, strided<[1], offset: ?>>)
164               outs(%y: memref<?xf32, strided<[1], offset: ?>>)
165  return
166}
167
// Transform script for @matvec_perm: matches every linalg.matvec under %arg1
// and tiles it with tile sizes [5, 6] and loop interchange [1, 0]. The
// expectations that follow this module look for the step-6 loop outside the
// step-5 loop.
168module attributes {transform.with_named_sequence} {
169  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
170      %0 = transform.structured.match ops{["linalg.matvec"]} in %arg1 : (!transform.any_op) -> !transform.any_op
171      %1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [5, 6] interchange = [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
172      transform.yield
173  }
174}
175
176// CHECK-LABEL: func @matvec_perm
177// CHECK-DAG:     %[[c0:.*]] = arith.constant 0 : index
178// CHECK-DAG:     %[[c5:.*]] = arith.constant 5 : index
179// CHECK-DAG:     %[[c6:.*]] = arith.constant 6 : index
180// CHECK:         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]]
181// CHECK:           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]]
182// CHECK:             linalg.matvec
183// CHECK:               ins({{.*}}: memref<?x?xf32, strided<[?, 1], offset: ?>>, memref<?xf32, strided<[1], offset: ?>>)
184// CHECK:              outs({{.*}}: memref<?xf32, strided<[1], offset: ?>>)
185
186// -----
187
// Payload for the multi-level tiling-with-interchange test: a single
// linalg.matmul on dynamically shaped 2-D f32 memrefs with a unit innermost
// stride (same operand types as @matmul).
188func.func @matmul_perm(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
189             %B: memref<?x?xf32, strided<[?, 1], offset: ?>>,
190             %C: memref<?x?xf32, strided<[?, 1], offset: ?>>) {
191  linalg.matmul ins(%A, %B: memref<?x?xf32, strided<[?, 1], offset: ?>>,
192                            memref<?x?xf32, strided<[?, 1], offset: ?>>)
193               outs(%C : memref<?x?xf32, strided<[?, 1], offset: ?>>)
194  return
195}
196
// Transform script for @matmul_perm: matches every linalg.matmul under %arg1
// and tiles it three times in a row, each step consuming the tiled-op handle
// of the previous one (%0 -> %1 -> %2):
//   - tile sizes [2000, 3000, 4000] with interchange [1, 2, 0],
//   - tile sizes [200, 300, 400] with interchange [1, 0, 2],
//   - tile sizes [20, 30, 40] with no interchange.
// The loop handles returned by each step are unused.
197module attributes {transform.with_named_sequence} {
198  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
199      %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
200      %1, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [2000, 3000, 4000] interchange = [1, 2, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
201      %2, %loops_2:3 = transform.structured.tile_using_for %1 tile_sizes [200, 300, 400] interchange = [1, 0, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
202      %3, %loops_3:3 = transform.structured.tile_using_for %2 tile_sizes [20, 30, 40] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
203      transform.yield
204  }
205}
206
207// CHECK-LABEL: func @matmul_perm
208// CHECK-DAG:     %[[c0:.*]] = arith.constant 0 : index
209// CHECK-DAG:     %[[c20:.*]] = arith.constant 20 : index
210// CHECK-DAG:     %[[c30:.*]] = arith.constant 30 : index
211// CHECK-DAG:     %[[c40:.*]] = arith.constant 40 : index
212// CHECK-DAG:     %[[c200:.*]] = arith.constant 200 : index
213// CHECK-DAG:     %[[c300:.*]] = arith.constant 300 : index
214// CHECK-DAG:     %[[c400:.*]] = arith.constant 400 : index
215// CHECK-DAG:     %[[c2000:.*]] = arith.constant 2000 : index
216// CHECK-DAG:     %[[c3000:.*]] = arith.constant 3000 : index
217// CHECK-DAG:     %[[c4000:.*]] = arith.constant 4000 : index
218// CHECK:         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] {
219// CHECK:           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] {
220// CHECK:             scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] {
221// CHECK:               scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c300]] {
222// CHECK:                 scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c200]] {
223// CHECK:                   scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c400]] {
224// CHECK:                     scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] {
225// CHECK:                       scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
226// CHECK:                         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
227// CHECK:                                 linalg.matmul
228// CHECK:                                  ins({{.*}}: memref<?x?xf32, strided<[?, 1], offset: ?>>, memref<?x?xf32, strided<[?, 1], offset: ?>>)
229// CHECK:                                   outs({{.*}}: memref<?x?xf32, strided<[?, 1], offset: ?>>)
230