xref: /llvm-project/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-scfforall.mlir (revision 76ead96c1d06ee0d828238bce96d0107e650b5fa)
1// RUN: mlir-opt -transform-interpreter -cse -split-input-file %s | FileCheck %s
2
// Payload function for this test: two chained matmuls on dynamically shaped
// f32 tensors. The result of the first matmul (%gemm0) is the first input of
// the second matmul and is also returned, so the function yields both matmul
// results.
3func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %rhs1 : tensor<?x?xf32>,
4    %init0 : tensor<?x?xf32>, %init1 : tensor<?x?xf32>)
5    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
6  %c0 = arith.constant 0 : index
7  %c1 = arith.constant 1 : index
8  %cst = arith.constant 0.0 : f32
9  %d0 = tensor.dim %lhs0, %c0 : tensor<?x?xf32>  // NOTE(review): %d0, %d1 and %d2 are not used by any other op in this function.
10  %d1 = tensor.dim %rhs0, %c1 : tensor<?x?xf32>
11  %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>  // fill of %init0 with %cst; outs operand of the first matmul
12  %gemm0 = linalg.matmul
13      ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
14  %d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
15  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>  // fill of %init1 with %cst; outs operand of the second matmul
16  %gemm1 = linalg.matmul
17      ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>  // consumes %gemm0 as its first input
18  return %gemm0, %gemm1 : tensor<?x?xf32>, tensor<?x?xf32>  // both matmul results are returned, %gemm0 first
19}
20
// Transform script applied to the payload: it matches the linalg.matmul ops,
// splits the resulting handle in two, and applies
// transform.test.fuse_and_yield to the second handle with `use_forall true`.
// The expectations further down look for an scf.forall loop in the output.
21module attributes {transform.with_named_sequence} {
22  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
23    %matmuls = transform.structured.match ops{["linalg.matmul"]} in %arg1  // handle to every linalg.matmul found in %arg1
24      : (!transform.any_op) -> !transform.any_op
25    %mm1, %mm2 = transform.split_handle %matmuls  // only %mm2 is used below; presumably it refers to the second matmul (%gemm1) - TODO confirm
26      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
27    %a, %b = transform.test.fuse_and_yield %mm2 [10] use_forall true  // [10] is presumably the tile-size list - confirm against the test op definition; %a and %b are not used afterwards
28      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
29    transform.yield
30  }
31}
// Expected output, verified by FileCheck with the directives below:
//  - one scf.forall with a single induction variable and two results, whose
//    shared_outs bind INIT1 first and INIT0 second;
//  - inside the loop, a fill + matmul on slices of LHS0 / RHS0 / the INIT0
//    iter-arg, followed by a fill + matmul that consumes the first matmul's
//    tile together with slices of RHS1 / the INIT1 iter-arg;
//  - both matmul tiles are written back with tensor.parallel_insert_slice;
//  - the function returns loop result #1 before #0, i.e. in the reverse of
//    the shared_outs order.
32//      CHECK: func.func @gemm_gemm_fusion_yield_both(
33// CHECK-SAME:     %[[LHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
34// CHECK-SAME:     %[[RHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
35// CHECK-SAME:     %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
36// CHECK-SAME:     %[[INIT0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
37// CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
38//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
39//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
40//      CHECK:   %[[RESULT:.+]]:2 = scf.forall (%[[IV:[a-zA-Z0-9]+]]) =
41// CHECK-SAME:       shared_outs(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]])
42//  CHECK-DAG:     %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
43//  CHECK-DAG:     %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0]
44//  CHECK-DAG:     %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
45//      CHECK:     %[[FILL0_TILE:.+]] = linalg.fill
46// CHECK-SAME:         outs(%[[INIT0_TILE]] :
47//      CHECK:     %[[GEMM0_TILE:.+]] = linalg.matmul
48// CHECK-SAME:         ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
49// CHECK-SAME:         outs(%[[FILL0_TILE]] :
50//  CHECK-DAG:     %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
51//  CHECK-DAG:     %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
52//      CHECK:     %[[FILL1_TILE:.+]] = linalg.fill
53// CHECK-SAME:         outs(%[[INIT1_TILE]] :
54//      CHECK:     %[[GEMM1_TILE:.+]] = linalg.matmul
55// CHECK-SAME:         ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] :
56// CHECK-SAME:         outs(%[[FILL1_TILE]] :
57//      CHECK:     scf.forall.in_parallel {
58//      CHECK:       tensor.parallel_insert_slice %[[GEMM1_TILE]] into %[[ITERARG0]][%[[IV]], 0]
59//      CHECK:       tensor.parallel_insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
60//      CHECK:   return %[[RESULT]]#1, %[[RESULT]]#0
61