xref: /llvm-project/mlir/test/Dialect/Linalg/transform-promotion.mlir (revision 6dc8de7a0abc7df8295273694fd9b951ed33708f)
1// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
2
// Input for the "promote every operand" case: a matmul tiled by three nested
// scf.for loops (steps 2000, 3000 and 4000) over dynamically shaped, strided
// f32 memrefs. The transform script in this split promotes operands 0, 1 and 2.
3func.func @promote_subview_matmul(%arg0: memref<?x?xf32, strided<[?, 1], offset: ?>>,
4                             %arg1: memref<?x?xf32, strided<[?, 1], offset: ?>>,
5                             %arg2: memref<?x?xf32, strided<[?, 1], offset: ?>>) {
  // Tile sizes, also used as the loop steps.
6  %c2000 = arith.constant 2000 : index
7  %c3000 = arith.constant 3000 : index
8  %c4000 = arith.constant 4000 : index
9  %c0 = arith.constant 0 : index
10  %c1 = arith.constant 1 : index
  // Loop upper bounds are read from the operands: both dims of %arg0 and
  // dim 1 of %arg1.
11  %0 = memref.dim %arg0, %c0 : memref<?x?xf32, strided<[?, 1], offset: ?>>
12  %1 = memref.dim %arg0, %c1 : memref<?x?xf32, strided<[?, 1], offset: ?>>
13  %2 = memref.dim %arg1, %c1 : memref<?x?xf32, strided<[?, 1], offset: ?>>
14  scf.for %arg3 = %c0 to %0 step %c2000 {
15    scf.for %arg4 = %c0 to %2 step %c3000 {
16      scf.for %arg5 = %c0 to %1 step %c4000 {
        // 2000 x 4000 tile of %arg0 at [%arg3, %arg5]. The expectations that
        // follow this function allocate 32000000 bytes for it
        // (2000 * 4000 four-byte elements).
17        %3 = memref.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] :
18             memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
        // 4000 x 3000 tile of %arg1 at [%arg5, %arg4] (48000000 bytes expected).
19        %4 = memref.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] :
20             memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
        // 2000 x 3000 tile of %arg2 at [%arg3, %arg4] (24000000 bytes expected).
21        %5 = memref.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
22             memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
        // The op matched by the transform script: reads %3 and %4, writes %5.
23        linalg.matmul ins(%3, %4: memref<?x?xf32, strided<[?, ?], offset: ?>>,
24                                  memref<?x?xf32, strided<[?, ?], offset: ?>>)
25                     outs(%5: memref<?x?xf32, strided<[?, ?], offset: ?>>)
26      }
27    }
28  }
29  return
30}
31// CHECK-LABEL: func @promote_subview_matmul
32// CHECK-DAG:     %[[c0:.*]] = arith.constant 0 : index
33// CHECK-DAG:     %[[c2000:.*]] = arith.constant 2000 : index
34// CHECK-DAG:     %[[c3000:.*]] = arith.constant 3000 : index
35// CHECK-DAG:     %[[c4000:.*]] = arith.constant 4000 : index
36// CHECK:         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] {
37// CHECK:           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] {
38// CHECK:             scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] {
39// CHECK:               %[[s0:.*]] = memref.subview {{.*}}: memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
40// CHECK:               %[[s1:.*]] = memref.subview {{.*}}: memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
41// CHECK:               %[[s2:.*]] = memref.subview {{.*}}: memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
42// CHECK:               %[[a0:.*]] = memref.alloc() : memref<32000000xi8>
43// CHECK:               %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref<?x?xf32>
44// CHECK:               %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1]
45// CHECK-SAME:            memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
46// CHECK:               %[[a1:.*]] = memref.alloc() : memref<48000000xi8>
47// CHECK:               %[[v1:.*]] = memref.view %[[a1]]{{.*}} : memref<48000000xi8> to memref<?x?xf32>
48// CHECK:               %[[l1:.*]] = memref.subview %[[v1]][0, 0] [%{{.*}}, %{{.*}}] [1, 1]
49// CHECK-SAME:            memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
50// CHECK:               %[[a2:.*]] = memref.alloc() : memref<24000000xi8>
51// CHECK:               %[[v2:.*]] = memref.view %[[a2]]{{.*}} : memref<24000000xi8> to memref<?x?xf32>
52// CHECK:               %[[l2:.*]] = memref.subview %[[v2]][0, 0] [%{{.*}}, %{{.*}}] [1, 1]
53// CHECK-SAME:            memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
54// CHECK:               linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
55// CHECK:               linalg.copy ins(%[[s1]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l1]] : memref<?x?xf32, strided{{.*}}>)
56// CHECK:               linalg.copy ins(%[[s2]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l2]] : memref<?x?xf32, strided{{.*}}>)
57// CHECK:               linalg.matmul
58// CHECK-SAME:                 ins(%[[v0]], %[[v1]] : memref<?x?xf32>, memref<?x?xf32>)
59// CHECK-SAME:                outs(%[[v2]] : memref<?x?xf32>)
60
// Transform script for this split: find every linalg.matmul under the root
// handle and promote all three of its operands.
61module attributes {transform.with_named_sequence} {
  // %arg1 is the root handle; it is only used as the search root of the match
  // below and is annotated transform.readonly.
62  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
63    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // operands_to_promote = [0, 1, 2] selects both inputs and the output; the
    // expected matmul above consumes the full memref.view buffers
    // (use_full_tiles_by_default is set).
64    %1 = transform.structured.promote %0 { operands_to_promote = [0, 1, 2], use_full_tiles_by_default } : (!transform.any_op) -> !transform.any_op
65    transform.yield
66  }
67}
68
69// -----
70
// Input for the "promote only the first operand" case. The IR has the same
// loop nest and subviews as the full-promotion test; the transform script in
// this split promotes operand 0 only, so the expectations require a single
// alloc/view/copy and a matmul that still reads the original subviews for its
// second input and its output.
71func.func @promote_first_subview_matmul(%arg0: memref<?x?xf32, strided<[?, 1], offset: ?>>,
72                             %arg1: memref<?x?xf32, strided<[?, 1], offset: ?>>,
73                             %arg2: memref<?x?xf32, strided<[?, 1], offset: ?>>) {
  // Tile sizes, also used as the loop steps.
74  %c2000 = arith.constant 2000 : index
75  %c3000 = arith.constant 3000 : index
76  %c4000 = arith.constant 4000 : index
77  %c0 = arith.constant 0 : index
78  %c1 = arith.constant 1 : index
  // Loop upper bounds: both dims of %arg0 and dim 1 of %arg1.
79  %0 = memref.dim %arg0, %c0 : memref<?x?xf32, strided<[?, 1], offset: ?>>
80  %1 = memref.dim %arg0, %c1 : memref<?x?xf32, strided<[?, 1], offset: ?>>
81  %2 = memref.dim %arg1, %c1 : memref<?x?xf32, strided<[?, 1], offset: ?>>
82  scf.for %arg3 = %c0 to %0 step %c2000 {
83    scf.for %arg4 = %c0 to %2 step %c3000 {
84      scf.for %arg5 = %c0 to %1 step %c4000 {
        // 2000 x 4000 tile of %arg0: the only operand that gets promoted
        // (32000000 bytes expected).
85        %3 = memref.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] :
86             memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
        // 4000 x 3000 tile of %arg1; expected to reach the matmul unchanged.
87        %4 = memref.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] :
88             memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
        // 2000 x 3000 tile of %arg2; expected to reach the matmul unchanged.
89        %5 = memref.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
90             memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
        // NOTE(review): nothing in this file reads the
        // __internal_linalg_transform__ attribute; the transform script selects
        // the op by name. It looks like a leftover marker -- confirm before
        // relying on it.
91        linalg.matmul {__internal_linalg_transform__ = "_promote_first_view_"}
92          ins(%3, %4: memref<?x?xf32, strided<[?, ?], offset: ?>>,
93                      memref<?x?xf32, strided<[?, ?], offset: ?>>)
94         outs(%5: memref<?x?xf32, strided<[?, ?], offset: ?>>)
95      }
96    }
97  }
98  return
99}
100// CHECK-LABEL: func @promote_first_subview_matmul
101// CHECK-DAG:     %[[c0:.*]] = arith.constant 0 : index
102// CHECK-DAG:     %[[c2000:.*]] = arith.constant 2000 : index
103// CHECK-DAG:     %[[c3000:.*]] = arith.constant 3000 : index
104// CHECK-DAG:     %[[c4000:.*]] = arith.constant 4000 : index
105// CHECK:   scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] {
106// CHECK:     scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] {
107// CHECK:       scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] {
108// CHECK:         %[[s0:.*]] = memref.subview {{.*}}: memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
109// CHECK:         %[[s1:.*]] = memref.subview {{.*}}: memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
110// CHECK:         %[[s2:.*]] = memref.subview {{.*}}: memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
111// CHECK:         %[[a0:.*]] = memref.alloc() : memref<32000000xi8>
112// CHECK:         %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref<?x?xf32>
113// CHECK:         %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
114// CHECK-NOT:     memref.alloc
115// CHECK-NOT:     memref.view
116// CHECK-NOT:     memref.subview
117// CHECK:         linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
118// CHECK-NOT:     linalg.copy
119// CHECK:         linalg.matmul
120// CHECK-SAME:           ins(%[[v0]], %[[s1]] : memref<?x?xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>)
121// CHECK-SAME:          outs(%[[s2]] : memref<?x?xf32, strided<[?, ?], offset: ?>>)
122
// Transform script for this split: find every linalg.matmul under the root
// handle and promote only its first operand.
123module attributes {transform.with_named_sequence} {
  // %arg1 is the root handle. It is only used as the search root of the match
  // below, so it is annotated transform.readonly to state that contract
  // explicitly on the entry point.
124  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
125    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // operands_to_promote = [0] restricts promotion to the first matmul input;
    // the expectations above forbid any further alloc/view/subview/copy.
126    %1 = transform.structured.promote %0 { operands_to_promote = [0], use_full_tiles_by_default } : (!transform.any_op) -> !transform.any_op
127    transform.yield
128  }
129}
130
131// -----
132
// Input for the aligned-promotion case on linalg.fill: a 2000 x 4000 subview of
// a dynamically shaped, strided f32 memref is filled with the constant 1.0.
// The transform script in this split promotes the fill's operand 1 with
// alignment = 32; the expectations require an alloc carrying
// {alignment = 32 : i64} of 32000000 bytes (2000 * 4000 four-byte elements).
133func.func @aligned_promote_fill(%arg0: memref<?x?xf32, strided<[?, 1], offset: ?>>) {
134  %c2000 = arith.constant 2000 : index
135  %c4000 = arith.constant 4000 : index
136  %c0 = arith.constant 0 : index
137  %c1 = arith.constant 1 : index
  // Fill value; the expectations capture it and require it on the final fill.
138  %cf = arith.constant 1.0 : f32
  // Tile taken at offset [0, 0] with unit strides.
139  %3 = memref.subview %arg0[%c0, %c0][%c2000, %c4000][%c1, %c1] :
140         memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  // The op matched by the transform script.
141  linalg.fill
142   ins(%cf : f32) outs(%3 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
143  return
144}
145// CHECK-LABEL: func @aligned_promote_fill
146// CHECK:         %[[cf:.*]] = arith.constant 1.{{.*}} : f32
147// CHECK:         %[[s0:.*]] = memref.subview {{.*}}: memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
148// CHECK:         %[[a0:.*]] = memref.alloc() {alignment = 32 : i64} : memref<32000000xi8>
149// CHECK:         %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref<?x?xf32>
150// CHECK:         %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
151// CHECK:         linalg.fill ins({{.*}} : f32) outs(%[[v0]] : memref<?x?xf32>)
152// CHECK:         linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
153// CHECK:         linalg.fill ins(%[[cf]] : f32) outs(%[[v0]] : memref<?x?xf32>)
154
// Transform script for this split: find every linalg.fill under the root
// handle and promote its operand 1 into an aligned buffer.
155module attributes {transform.with_named_sequence} {
  // %arg1 is the root handle. It is only used as the search root of the match
  // below, so it is annotated transform.readonly to state that contract
  // explicitly on the entry point.
156  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
157    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // operands_to_promote = [1] selects the fill's outs memref (operand 0 is
    // the scalar). alignment = 32 is what the expected memref.alloc carries as
    // {alignment = 32 : i64}.
158    %1 = transform.structured.promote %0 { operands_to_promote = [1], use_full_tile_buffers = [false, true], alignment = 32} : (!transform.any_op) -> !transform.any_op
159    transform.yield
160  }
161}
162
163// -----
164
// Complex-element variant of the aligned fill promotion test: the memref
// element type is complex<f32>, and the expectations require an alloc of
// 64000000 bytes, twice the 32000000 expected by the f32 variant for the same
// 2000 x 4000 tile.
165func.func @aligned_promote_fill_complex(%arg0: memref<?x?xcomplex<f32>, strided<[?, 1], offset: ?>>) {
166  %c2000 = arith.constant 2000 : index
167  %c4000 = arith.constant 4000 : index
168  %c0 = arith.constant 0 : index
169  %c1 = arith.constant 1 : index
  // Fill value: a complex number built from the same f32 constant for both
  // components.
170  %cf = arith.constant 1.0 : f32
171  %cc = complex.create %cf, %cf : complex<f32>
  // Tile taken at offset [0, 0] with unit strides.
172  %3 = memref.subview %arg0[%c0, %c0][%c2000, %c4000][%c1, %c1] :
173         memref<?x?xcomplex<f32>, strided<[?, 1], offset: ?>> to memref<?x?xcomplex<f32>, strided<[?, ?], offset: ?>>
  // The op matched by the transform script.
174  linalg.fill ins(%cc : complex<f32>)
175             outs(%3 : memref<?x?xcomplex<f32>, strided<[?, ?], offset: ?>>)
176  return
177}
178// CHECK-LABEL: func @aligned_promote_fill_complex
179// CHECK:         %[[cc:.*]] = complex.create {{.*}} : complex<f32>
180// CHECK:         %[[s0:.*]] = memref.subview {{.*}}: memref<?x?xcomplex<f32>, strided{{.*}}> to memref<?x?xcomplex<f32>, strided{{.*}}>
181// CHECK:         %[[a0:.*]] = memref.alloc() {alignment = 32 : i64} : memref<64000000xi8>
182// CHECK:         %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<64000000xi8> to memref<?x?xcomplex<f32>>
183// CHECK:         %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xcomplex<f32>> to memref<?x?xcomplex<f32>, strided<[?, 1]>>
184// CHECK:         linalg.fill ins({{.*}} : complex<f32>) outs(%[[v0]] : memref<?x?xcomplex<f32>>)
185// CHECK:         linalg.copy ins(%[[s0]] : memref<?x?xcomplex<f32>, strided{{.*}}>) outs(%[[l0]] : memref<?x?xcomplex<f32>, strided{{.*}}>)
186// CHECK:         linalg.fill ins(%[[cc]] : complex<f32>) outs(%[[v0]] : memref<?x?xcomplex<f32>>)
187
// Transform script for this split: find every linalg.fill under the root
// handle and promote its operand 1 into an aligned buffer.
188module attributes {transform.with_named_sequence} {
  // %arg1 is the root handle. It is only used as the search root of the match
  // below, so it is annotated transform.readonly to state that contract
  // explicitly on the entry point.
189  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
190    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // operands_to_promote = [1] selects the fill's outs memref (operand 0 is
    // the complex scalar). alignment = 32 is what the expected memref.alloc
    // carries as {alignment = 32 : i64}.
191    %1 = transform.structured.promote %0 { operands_to_promote = [1], use_full_tile_buffers = [false, true], alignment = 32} : (!transform.any_op) -> !transform.any_op
192    transform.yield
193  }
194}
195