// RUN: mlir-opt -transform-interpreter -cse -split-input-file %s | FileCheck %s

// Tests for `transform.test.fuse_and_yield`. In each case below the expected
// IR shows:
//   * the targeted op tiled into a single scf.for loop,
//   * its producers fused (tiled) into the same loop body, and
//   * producer results that the function also returns being inserted into
//     their own loop iter_args and yielded as extra loop results.
// Each function header is matched with a label directive so that the
// expectations of one split-input case cannot match output of another case.

// Two chained matmuls, each accumulating into a linalg.fill result. The second
// matmul (%gemm1) is the tiling target. Its producer chain (%fill0 -> %gemm0)
// and its own fill are expected to be fused into the loop. Because %gemm0 is
// also returned from the function, its tiled result must be yielded as well.
func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %rhs1 : tensor<?x?xf32>,
    %init0 : tensor<?x?xf32>, %init1 : tensor<?x?xf32>)
    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  // NOTE(review): %d0, %d1 and %d2 are not used anywhere below.
  %d0 = tensor.dim %lhs0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %rhs0, %c1 : tensor<?x?xf32>
  %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %gemm0 = linalg.matmul
      ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %gemm1 = linalg.matmul
      ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %gemm0, %gemm1 : tensor<?x?xf32>, tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %matmuls = transform.structured.match ops{["linalg.matmul"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    // %mm2 is the consumer matmul (%gemm1) that gets tiled; %mm1 is not used
    // directly. Per the expected IR, the first matmul is fused in as a producer.
    %mm1, %mm2 = transform.split_handle %matmuls
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    // `[10]` is presumably the tile-size list. The expected IR tiles only the
    // outermost dimension (slices at [%iv, 0]).
    %a, %b = transform.test.fuse_and_yield %mm2 [10]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
// CHECK-LABEL: func.func @gemm_gemm_fusion_yield_both(
// CHECK-SAME:     %[[LHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME:     %[[RHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
// CHECK-SAME:     %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
// CHECK-SAME:     %[[INIT0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
// CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:    %[[C1:.+]] = arith.constant 1 : index
// The consumer's init is the first iter_arg; the yielded producer is second.
// CHECK:        %[[RESULT:.+]]:2 = scf.for %[[IV:[a-zA-Z0-9]+]] =
// CHECK-SAME:       iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]])
// Fused producer chain: fill0 + gemm0 on a row tile.
// CHECK-DAG:      %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
// CHECK-DAG:      %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0]
// CHECK-DAG:      %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
// CHECK:          %[[FILL0_TILE:.+]] = linalg.fill
// CHECK-SAME:         outs(%[[INIT0_TILE]] :
// CHECK:          %[[GEMM0_TILE:.+]] = linalg.matmul
// CHECK-SAME:         ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
// CHECK-SAME:         outs(%[[FILL0_TILE]] :
// Tiled consumer: fill1 + gemm1 reading the fused gemm0 tile.
// CHECK-DAG:      %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
// CHECK-DAG:      %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
// CHECK:          %[[FILL1_TILE:.+]] = linalg.fill
// CHECK-SAME:         outs(%[[INIT1_TILE]] :
// CHECK:          %[[GEMM1_TILE:.+]] = linalg.matmul
// CHECK-SAME:         ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] :
// CHECK-SAME:         outs(%[[FILL1_TILE]] :
// Both the consumer tile and the producer tile are written back and yielded.
// CHECK:          %[[INSERT0:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG0]][%[[IV]], 0]
// CHECK:          %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
// CHECK:          scf.yield %[[INSERT0]], %[[INSERT1]]
// CHECK:        return %[[RESULT]]#1, %[[RESULT]]#0

// -----

// A two-result elementwise linalg.generic (mulf and addf of the same inputs)
// feeds a linalg.add. The second generic output is written through a
// transposed (j, i) indexing map. The add is the tiling target, and the
// generic is expected to be fused into the loop. Both generic results are
// returned from the function, so both must be yielded alongside the add's
// result.
func.func @multiple_outputs_fusion_yield_all(%lhs0: tensor<32x32xf32>,
    %rhs0: tensor<32x32xf32>, %init0: tensor<32x32xf32>, %init1: tensor<32x32xf32>,
    %rhs1: tensor<32x32xf32>, %init2: tensor<32x32xf32>)
    -> (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) {
  %out0, %out1 = linalg.generic {
      indexing_maps = [affine_map<(i, j) -> (i, j)>,
                       affine_map<(i, j) -> (i, j)>,
                       affine_map<(i, j) -> (i, j)>,
                       affine_map<(i, j) -> (j, i)>],
      iterator_types = ["parallel", "parallel"]
    }
    ins(%lhs0, %rhs0: tensor<32x32xf32>, tensor<32x32xf32>)
    outs(%init0, %init1: tensor<32x32xf32>, tensor<32x32xf32>) {
    ^bb0(%0: f32, %1: f32, %2: f32, %3: f32):
      %4 = arith.mulf %0, %1 : f32
      %5 = arith.addf %0, %1 : f32
      linalg.yield %4, %5: f32, f32
  } -> (tensor<32x32xf32>, tensor<32x32xf32>)

  %out3 = linalg.add ins(%out0, %rhs1: tensor<32x32xf32>, tensor<32x32xf32>) outs(%init2: tensor<32x32xf32>) -> tensor<32x32xf32>

  return %out0, %out1, %out3 : tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
    %add = transform.structured.match ops{["linalg.add"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    // `[16]` is presumably the tile-size list. The expected IR tiles only the
    // outermost loop dimension.
    %a, %b = transform.test.fuse_and_yield %add [16]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
// CHECK-LABEL: func.func @multiple_outputs_fusion_yield_all(
// CHECK-SAME:     %[[LHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME:     %[[RHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
// CHECK-SAME:     %[[INIT0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
// CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
// CHECK-SAME:     %[[RHS1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
// CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<32x32xf32>)
// The consumer's init is the first iter_arg, followed by both generic inits.
// CHECK:        %[[RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9]+]] =
// CHECK-SAME:       iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT2]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
// The transposed output is sliced along its second dimension: [0, %iv].
// CHECK-DAG:      %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
// CHECK-DAG:      %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][%[[IV]], 0]
// CHECK-DAG:      %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
// CHECK-DAG:      %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG2]][0, %[[IV]]]
// CHECK:          %[[GENERIC_TILE:.+]]:2 = linalg.generic
// CHECK-SAME:         ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
// CHECK-SAME:         outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
// CHECK-DAG:      %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][%[[IV]], 0]
// CHECK-DAG:      %[[INIT2_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
// CHECK:          %[[ADD_TILE:.+]] = linalg.add
// CHECK-SAME:         ins(%[[GENERIC_TILE]]#0, %[[RHS1_TILE]] :
// CHECK-SAME:         outs(%[[INIT2_TILE]] :
// CHECK:          %[[INSERT0:.+]] = tensor.insert_slice %[[ADD_TILE]] into %[[ITERARG0]][%[[IV]], 0]
// CHECK:          %[[INSERT1:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#0 into %[[ITERARG1]][%[[IV]], 0]
// CHECK:          %[[INSERT2:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#1 into %[[ITERARG2]][0, %[[IV]]]
// CHECK:          scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]]
// CHECK:        return %[[RESULT]]#1, %[[RESULT]]#2, %[[RESULT]]#0