xref: /llvm-project/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir (revision 7ef08eacd5e125eca0881c00f83c4b108ba7c502)
1// RUN: mlir-opt -transform-interpreter -cse -split-input-file %s | FileCheck %s
2
// Test input: two chained matmuls, both of which are returned.
//   %gemm0 = %lhs0 x %rhs0, accumulated into %init0 after it is filled with 0.0
//   %gemm1 = %gemm0 x %rhs1, accumulated into %init1 after it is filled with 0.0
// %gemm0 is therefore both an operand of %gemm1 and a function result.
3func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %rhs1 : tensor<?x?xf32>,
4    %init0 : tensor<?x?xf32>, %init1 : tensor<?x?xf32>)
5    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
6  %c0 = arith.constant 0 : index
7  %c1 = arith.constant 1 : index
8  %cst = arith.constant 0.0 : f32
  // NOTE(review): %d0, %d1 and %d2 have no uses anywhere in this function.
9  %d0 = tensor.dim %lhs0, %c0 : tensor<?x?xf32>
10  %d1 = tensor.dim %rhs0, %c1 : tensor<?x?xf32>
  // First matmul: zero-fill %init0, then accumulate %lhs0 x %rhs0 into it.
11  %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
12  %gemm0 = linalg.matmul
13      ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
14  %d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
  // Second matmul: zero-fill %init1, then accumulate %gemm0 x %rhs1 into it.
15  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
16  %gemm1 = linalg.matmul
17      ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
  // Both the intermediate product and the final product are returned.
18  return %gemm0, %gemm1 : tensor<?x?xf32>, tensor<?x?xf32>
19}
20
// Transform script for the case above: collect every linalg.matmul in the
// payload, split the handle into two, and apply the test fuse-and-yield
// transform to the second one.
21module attributes {transform.with_named_sequence} {
22  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
23    %matmuls = transform.structured.match ops{["linalg.matmul"]} in %arg1
24      : (!transform.any_op) -> !transform.any_op
    // Two results are requested, so exactly two matched matmuls are expected.
    // NOTE(review): %mm2 is presumably the second matmul in the function (the
    // one consuming the first matmul's result) - confirm against match order.
25    %mm1, %mm2 = transform.split_handle %matmuls
26      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    // NOTE(review): [10] is presumably the tile size list for the tiled op -
    // confirm against the test op definition. %a and %b are not used afterwards.
27    %a, %b = transform.test.fuse_and_yield %mm2 [10]
28      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
29    transform.yield
30  }
31}
// Expected output for @gemm_gemm_fusion_yield_both: one scf.for that carries
// two tensors. The first iter_arg is bound to the %init1 argument and the
// second to the %init0 argument. Inside the loop, both fill + matmul pairs
// work on output slices taken at [iv, 0]; the second matmul reads the tile
// produced by the first. Each matmul tile is inserted back into its own
// iter_arg, and the function returns loop result #1 followed by #0, i.e. the
// %init0-based tensor first and the %init1-based tensor second.
32//      CHECK: func.func @gemm_gemm_fusion_yield_both(
33// CHECK-SAME:     %[[LHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
34// CHECK-SAME:     %[[RHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
35// CHECK-SAME:     %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
36// CHECK-SAME:     %[[INIT0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
37// CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
38//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
39//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
40//      CHECK:   %[[RESULT:.+]]:2 = scf.for %[[IV:[a-zA-Z0-9]+]] =
41// CHECK-SAME:       iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]])
42//  CHECK-DAG:     %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
43//  CHECK-DAG:     %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0]
44//  CHECK-DAG:     %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
45//      CHECK:     %[[FILL0_TILE:.+]] = linalg.fill
46// CHECK-SAME:         outs(%[[INIT0_TILE]] :
47//      CHECK:     %[[GEMM0_TILE:.+]] = linalg.matmul
48// CHECK-SAME:         ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
49// CHECK-SAME:         outs(%[[FILL0_TILE]] :
50//  CHECK-DAG:     %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
51//  CHECK-DAG:     %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
52//      CHECK:     %[[FILL1_TILE:.+]] = linalg.fill
53// CHECK-SAME:         outs(%[[INIT1_TILE]] :
54//      CHECK:     %[[GEMM1_TILE:.+]] = linalg.matmul
55// CHECK-SAME:         ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] :
56// CHECK-SAME:         outs(%[[FILL1_TILE]] :
57//      CHECK:     %[[INSERT0:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG0]][%[[IV]], 0]
58//      CHECK:     %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
59//      CHECK:     scf.yield %[[INSERT0]], %[[INSERT1]]
60//      CHECK:   return %[[RESULT]]#1, %[[RESULT]]#0
61
62// -----
63
// Test input: a two-result linalg.generic feeding a linalg.add, with all three
// values returned.
//   %out0 = elementwise %lhs0 * %rhs0, written into %init0 with the identity map
//   %out1 = elementwise %lhs0 + %rhs0, written into %init1 through the
//           transposed map (i, j) -> (j, i)
//   %out3 = %out0 + %rhs1, written into %init2
// Only %out0 is consumed by the add; %out1 is used solely as a function result.
64func.func @multiple_outputs_fusion_yield_all(%lhs0: tensor<32x32xf32>,
65                       %rhs0: tensor<32x32xf32>, %init0: tensor<32x32xf32>, %init1: tensor<32x32xf32>,
66                       %rhs1: tensor<32x32xf32>, %init2: tensor<32x32xf32>)
67                       -> (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) {
68  %out0, %out1 = linalg.generic {
    // Maps are listed in operand order: two inputs, then two outputs. Only the
    // last one (the second output) is transposed.
69    indexing_maps = [affine_map<(i, j) -> (i, j)>,
70                     affine_map<(i, j) -> (i, j)>,
71                     affine_map<(i, j) -> (i, j)>,
72                     affine_map<(i, j) -> (j, i)>],
73    iterator_types = ["parallel", "parallel"]
74  }
75  ins(%lhs0, %rhs0: tensor<32x32xf32>, tensor<32x32xf32>)
76  outs(%init0, %init1: tensor<32x32xf32>, tensor<32x32xf32>) {
  // %0 and %1 are the input elements; %2 and %3 are the current output elements
  // and are not read by the body.
77  ^bb0(%0: f32, %1: f32, %2: f32, %3: f32):
78    %4 = arith.mulf %0, %1 : f32
79    %5 = arith.addf %0, %1 : f32
80    linalg.yield %4, %5: f32, f32
81  } -> (tensor<32x32xf32>, tensor<32x32xf32>)
82
83  %out3 = linalg.add ins(%out0, %rhs1: tensor<32x32xf32>, tensor<32x32xf32>) outs(%init2: tensor<32x32xf32>) -> tensor<32x32xf32>
84
85  return %out0, %out1, %out3 : tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>
86}
87
// Transform script for the case above: locate the linalg.add in the payload and
// apply the test fuse-and-yield transform to it.
88module attributes {transform.with_named_sequence} {
89  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
90    %add = transform.structured.match ops{["linalg.add"]} in %arg0
91      : (!transform.any_op) -> !transform.any_op
    // NOTE(review): [16] is presumably the tile size list for the tiled op -
    // confirm against the test op definition. %a and %b are not used afterwards.
92    %a, %b = transform.test.fuse_and_yield %add [16]
93      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
94    transform.yield
95  }
96}
// Expected output for @multiple_outputs_fusion_yield_all: one scf.for that
// carries three tensors, bound in order to the %init2, %init0 and %init1
// arguments. Inside the loop, the two-result generic and the add both operate
// on slices; every slice is taken at [iv, 0] except the generic's second
// output, which is sliced and re-inserted at [0, iv] (that output uses the
// transposed indexing map in the input). The add tile and both generic result
// tiles are inserted back into their iter_args, and the function returns loop
// results #1, #2, #0, i.e. the %init0-, %init1- and %init2-based tensors in
// that order.
97//      CHECK: func.func @multiple_outputs_fusion_yield_all(
98// CHECK-SAME:     %[[LHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
99// CHECK-SAME:     %[[RHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
100// CHECK-SAME:     %[[INIT0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
101// CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
102// CHECK-SAME:     %[[RHS1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
103// CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<32x32xf32>)
104//      CHECK:   %[[RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9]+]] =
105// CHECK-SAME:       iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT2]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
106//  CHECK-DAG:     %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
107//  CHECK-DAG:     %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][%[[IV]], 0]
108//  CHECK-DAG:     %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
109//  CHECK-DAG:     %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG2]][0, %[[IV]]]
110//      CHECK:     %[[GENERIC_TILE:.+]]:2 = linalg.generic
111// CHECK-SAME:         ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
112// CHECK-SAME:         outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
113//  CHECK-DAG:     %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][%[[IV]], 0]
114//  CHECK-DAG:     %[[INIT2_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
115//      CHECK:     %[[ADD_TILE:.+]] = linalg.add
116// CHECK-SAME:         ins(%[[GENERIC_TILE]]#0, %[[RHS1_TILE]] :
117// CHECK-SAME:         outs(%[[INIT2_TILE]] :
118//      CHECK:     %[[INSERT0:.+]] = tensor.insert_slice %[[ADD_TILE]] into %[[ITERARG0]][%[[IV]], 0]
119//      CHECK:     %[[INSERT1:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#0 into %[[ITERARG1]][%[[IV]], 0]
120//      CHECK:     %[[INSERT2:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#1 into %[[ITERARG2]][0, %[[IV]]]
121//      CHECK:     scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]]
122//      CHECK:   return %[[RESULT]]#1, %[[RESULT]]#2, %[[RESULT]]#0
123