// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s

// Promotes all three operands of a tiled matmul using full tiles. Each promoted
// operand is backed by an i8 buffer sized for its full f32 tile:
//   A: 2000 * 4000 * 4 bytes = 32000000
//   B: 4000 * 3000 * 4 bytes = 48000000
//   C: 2000 * 3000 * 4 bytes = 24000000
func.func @promote_subview_matmul(%arg0: memref<?x?xf32, strided<[?, 1], offset: ?>>,
                                  %arg1: memref<?x?xf32, strided<[?, 1], offset: ?>>,
                                  %arg2: memref<?x?xf32, strided<[?, 1], offset: ?>>) {
  %c2000 = arith.constant 2000 : index
  %c3000 = arith.constant 3000 : index
  %c4000 = arith.constant 4000 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = memref.dim %arg0, %c0 : memref<?x?xf32, strided<[?, 1], offset: ?>>
  %1 = memref.dim %arg0, %c1 : memref<?x?xf32, strided<[?, 1], offset: ?>>
  %2 = memref.dim %arg1, %c1 : memref<?x?xf32, strided<[?, 1], offset: ?>>
  scf.for %arg3 = %c0 to %0 step %c2000 {
    scf.for %arg4 = %c0 to %2 step %c3000 {
      scf.for %arg5 = %c0 to %1 step %c4000 {
        %3 = memref.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] :
          memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
        %4 = memref.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] :
          memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
        %5 = memref.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
          memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
        linalg.matmul ins(%3, %4: memref<?x?xf32, strided<[?, ?], offset: ?>>,
                                  memref<?x?xf32, strided<[?, ?], offset: ?>>)
                      outs(%5: memref<?x?xf32, strided<[?, ?], offset: ?>>)
      }
    }
  }
  return
}
// CHECK-LABEL: func @promote_subview_matmul
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c2000:.*]] = arith.constant 2000 : index
// CHECK-DAG: %[[c3000:.*]] = arith.constant 3000 : index
// CHECK-DAG: %[[c4000:.*]] = arith.constant 4000 : index
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] {
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] {
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] {
// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
// CHECK: %[[s1:.*]] = memref.subview {{.*}}: memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
// CHECK: %[[s2:.*]] = memref.subview {{.*}}: memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
// CHECK: %[[a0:.*]] = memref.alloc() : memref<32000000xi8>
// CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref<?x?xf32>
// CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1]
// CHECK-SAME: memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
// CHECK: %[[a1:.*]] = memref.alloc() : memref<48000000xi8>
// CHECK: %[[v1:.*]] = memref.view %[[a1]]{{.*}} : memref<48000000xi8> to memref<?x?xf32>
// CHECK: %[[l1:.*]] = memref.subview %[[v1]][0, 0] [%{{.*}}, %{{.*}}] [1, 1]
// CHECK-SAME: memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
// CHECK: %[[a2:.*]] = memref.alloc() : memref<24000000xi8>
// CHECK: %[[v2:.*]] = memref.view %[[a2]]{{.*}} : memref<24000000xi8> to memref<?x?xf32>
// CHECK: %[[l2:.*]] = memref.subview %[[v2]][0, 0] [%{{.*}}, %{{.*}}] [1, 1]
// CHECK-SAME: memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
// CHECK: linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
// CHECK: linalg.copy ins(%[[s1]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l1]] : memref<?x?xf32, strided{{.*}}>)
// CHECK: linalg.copy ins(%[[s2]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l2]] : memref<?x?xf32, strided{{.*}}>)
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[v0]], %[[v1]] : memref<?x?xf32>, memref<?x?xf32>)
// CHECK-SAME: outs(%[[v2]] : memref<?x?xf32>)

module attributes {transform.with_named_sequence} {
  // The root handle is only read (by structured.match), so it is marked readonly.
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.promote %0 { operands_to_promote = [0, 1, 2], use_full_tiles_by_default } : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Promotes only operand 0 (the A tile) of a tiled matmul. Exactly one buffer
// and one copy-in are expected; the B and C subviews feed the matmul unchanged.
func.func @promote_first_subview_matmul(%arg0: memref<?x?xf32, strided<[?, 1], offset: ?>>,
                                        %arg1: memref<?x?xf32, strided<[?, 1], offset: ?>>,
                                        %arg2: memref<?x?xf32, strided<[?, 1], offset: ?>>) {
  %c2000 = arith.constant 2000 : index
  %c3000 = arith.constant 3000 : index
  %c4000 = arith.constant 4000 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = memref.dim %arg0, %c0 : memref<?x?xf32, strided<[?, 1], offset: ?>>
  %1 = memref.dim %arg0, %c1 : memref<?x?xf32, strided<[?, 1], offset: ?>>
  %2 = memref.dim %arg1, %c1 : memref<?x?xf32, strided<[?, 1], offset: ?>>
  scf.for %arg3 = %c0 to %0 step %c2000 {
    scf.for %arg4 = %c0 to %2 step %c3000 {
      scf.for %arg5 = %c0 to %1 step %c4000 {
        %3 = memref.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] :
          memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
        %4 = memref.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] :
          memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
        %5 = memref.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
          memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
        linalg.matmul {__internal_linalg_transform__ = "_promote_first_view_"}
          ins(%3, %4: memref<?x?xf32, strided<[?, ?], offset: ?>>,
                      memref<?x?xf32, strided<[?, ?], offset: ?>>)
          outs(%5: memref<?x?xf32, strided<[?, ?], offset: ?>>)
      }
    }
  }
  return
}
// CHECK-LABEL: func @promote_first_subview_matmul
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c2000:.*]] = arith.constant 2000 : index
// CHECK-DAG: %[[c3000:.*]] = arith.constant 3000 : index
// CHECK-DAG: %[[c4000:.*]] = arith.constant 4000 : index
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] {
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] {
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] {
// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
// CHECK: %[[s1:.*]] = memref.subview {{.*}}: memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
// CHECK: %[[s2:.*]] = memref.subview {{.*}}: memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
// CHECK: %[[a0:.*]] = memref.alloc() : memref<32000000xi8>
// CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref<?x?xf32>
// CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
// CHECK-NOT: memref.alloc
// CHECK-NOT: memref.view
// CHECK-NOT: memref.subview
// CHECK: linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
// CHECK-NOT: linalg.copy
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[v0]], %[[s1]] : memref<?x?xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>)
// CHECK-SAME: outs(%[[s2]] : memref<?x?xf32, strided<[?, ?], offset: ?>>)

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.promote %0 { operands_to_promote = [0], use_full_tiles_by_default } : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Promotes the outs operand (operand 1) of a linalg.fill with alignment = 32.
// use_full_tile_buffers = [false, true] requests a full-tile buffer for that
// operand. The whole promoted view is filled before the partial tile is copied
// in, and the original fill then runs on the promoted view.
func.func @aligned_promote_fill(%arg0: memref<?x?xf32, strided<[?, 1], offset: ?>>) {
  %c2000 = arith.constant 2000 : index
  %c4000 = arith.constant 4000 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cf = arith.constant 1.0 : f32
  %3 = memref.subview %arg0[%c0, %c0][%c2000, %c4000][%c1, %c1] :
    memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  linalg.fill
    ins(%cf : f32) outs(%3 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
  return
}
// CHECK-LABEL: func @aligned_promote_fill
// CHECK: %[[cf:.*]] = arith.constant 1.{{.*}} : f32
// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
// CHECK: %[[a0:.*]] = memref.alloc() {alignment = 32 : i64} : memref<32000000xi8>
// CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref<?x?xf32>
// CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
// CHECK: linalg.fill ins({{.*}} : f32) outs(%[[v0]] : memref<?x?xf32>)
// CHECK: linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
// CHECK: linalg.fill ins(%[[cf]] : f32) outs(%[[v0]] : memref<?x?xf32>)

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.promote %0 { operands_to_promote = [1], use_full_tile_buffers = [false, true], alignment = 32} : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Same as the f32 fill case, but with complex<f32> elements (8 bytes each), so
// the promoted 2000 x 4000 tile needs 64000000 bytes.
func.func @aligned_promote_fill_complex(%arg0: memref<?x?xcomplex<f32>, strided<[?, 1], offset: ?>>) {
  %c2000 = arith.constant 2000 : index
  %c4000 = arith.constant 4000 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cf = arith.constant 1.0 : f32
  %cc = complex.create %cf, %cf : complex<f32>
  %3 = memref.subview %arg0[%c0, %c0][%c2000, %c4000][%c1, %c1] :
    memref<?x?xcomplex<f32>, strided<[?, 1], offset: ?>> to memref<?x?xcomplex<f32>, strided<[?, ?], offset: ?>>
  linalg.fill ins(%cc : complex<f32>)
              outs(%3 : memref<?x?xcomplex<f32>, strided<[?, ?], offset: ?>>)
  return
}
// CHECK-LABEL: func @aligned_promote_fill_complex
// CHECK: %[[cc:.*]] = complex.create {{.*}} : complex<f32>
// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref<?x?xcomplex<f32>, strided{{.*}}> to memref<?x?xcomplex<f32>, strided{{.*}}>
// CHECK: %[[a0:.*]] = memref.alloc() {alignment = 32 : i64} : memref<64000000xi8>
// CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<64000000xi8> to memref<?x?xcomplex<f32>>
// CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xcomplex<f32>> to memref<?x?xcomplex<f32>, strided<[?, 1]>>
// CHECK: linalg.fill ins({{.*}} : complex<f32>) outs(%[[v0]] : memref<?x?xcomplex<f32>>)
// CHECK: linalg.copy ins(%[[s0]] : memref<?x?xcomplex<f32>, strided{{.*}}>) outs(%[[l0]] : memref<?x?xcomplex<f32>, strided{{.*}}>)
// CHECK: linalg.fill ins(%[[cc]] : complex<f32>) outs(%[[v0]] : memref<?x?xcomplex<f32>>)

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.promote %0 { operands_to_promote = [1], use_full_tile_buffers = [false, true], alignment = 32} : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}