// RUN: mlir-opt -transform-interpreter -cse -split-input-file %s | FileCheck %s

// Tests for `transform.structured.fuse`: the op selected by
// `transform.structured.match` is tiled with the given tile sizes into an
// scf.for loop nest, and its producers are fused into the loop body. Each
// split below exercises one producer/consumer pattern; the expected IR is the
// output after -cse.

// Tile a matmul with sizes [10, 20] and fuse its linalg.fill producer. The
// fused fill writes into a tile of the innermost loop-carried destination.
func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
  %gemm = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %gemm : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b, %c = transform.structured.fuse %matmul [10, 20]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
// CHECK: func.func @gemm_fill_fusion(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
// CHECK: %[[INIT:.+]] = tensor.empty
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]])
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
// CHECK: %[[FILL_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT_TILE]] :
// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME: outs(%[[FILL_TILE]] :
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
// CHECK: scf.yield %[[INSERT]]

// -----

// Tile an elementwise linalg.generic (matmul result plus a 1-D operand
// broadcast along d0) with sizes [10, 20]; the matmul and fill producers are
// both fused. The fused fill tiles %init directly, while the generic's outs
// tile is taken from the loop-carried iter_arg.
func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
    %arg2 : tensor<?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
  %gemm = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
  %generic = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%gemm, %arg2 : tensor<?x?xf32>, tensor<?xf32>) outs(%init : tensor<?x?xf32>) {
    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
      %add = arith.addf %b0, %b1 : f32
      linalg.yield %add : f32
  } -> tensor<?x?xf32>
  return %generic : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b, %c = transform.structured.fuse %generic [10, 20]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
// CHECK: func.func @gemm_generic_fusion(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xf32>)
// CHECK: %[[INIT:.+]] = tensor.empty
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]])
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
// CHECK: %[[FILL_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT_TILE]] :
// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME: outs(%[[FILL_TILE]] :
// CHECK-DAG: %[[BIAS_TILE:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]]]
// CHECK-DAG: %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
// CHECK: %[[GENERIC_TILE:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GEMM_TILE]], %[[BIAS_TILE]] :
// CHECK-SAME: outs(%[[OUTS_TILE]] :
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
// CHECK: scf.yield %[[INSERT]]

// -----

// Tile the second of two chained matmuls along rows only (size [10]). The
// first matmul and both fills are fused, so the intermediate matmul result is
// produced one row tile at a time inside the single scf.for.
func.func @gemm_gemm_fusion(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %rhs1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  %d0 = tensor.dim %lhs0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %rhs0, %c1 : tensor<?x?xf32>
  %init0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %gemm0 = linalg.matmul
      ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
  %init1 = tensor.empty(%d0, %d2) : tensor<?x?xf32>
  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %gemm1 = linalg.matmul
      ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %gemm1 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %matmuls = transform.structured.match ops{["linalg.matmul"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %mm1, %mm2 = transform.split_handle %matmuls
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %a, %b = transform.structured.fuse %mm2 [10]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
// CHECK: func.func @gemm_gemm_fusion(
// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[LHS0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[RHS0]], %[[C1]]
// CHECK-DAG: %[[INIT0:.+]] = tensor.empty(%[[D0]], %[[D1]])
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[RHS1]], %[[C1]]
// CHECK: %[[INIT1:.+]] = tensor.empty(%[[D0]], %[[D2]])
// CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG:.+]] = %[[INIT1]])
// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0]
// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]], 0]
// CHECK: %[[FILL0_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT0_TILE]] :
// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
// CHECK-SAME: outs(%[[FILL0_TILE]] :
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG]][%[[IV]], 0]
// CHECK: %[[FILL1_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT1_TILE]] :
// CHECK: %[[GEMM1_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] :
// CHECK-SAME: outs(%[[FILL1_TILE]] :
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG]][%[[IV]], 0]
// CHECK: scf.yield %[[INSERT]]

// -----

// Tile a transposing linalg.generic (output map (d1, d0)) with sizes [10, 20]
// and fuse the matmul/fill producers. The producer tiles are taken at
// [IV0, IV1], while the generic's outs slice and the insert_slice use the
// swapped offsets [IV1, IV0].
func.func @gemm_transpose_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %init0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %gemm = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
  %init1 = tensor.empty(%d1, %d0) : tensor<?x?xf32>
  %transpose = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%gemm : tensor<?x?xf32>) outs(%init1 : tensor<?x?xf32>) {
    ^bb0(%b0 : f32, %b1 : f32):
      linalg.yield %b0 : f32
  } -> tensor<?x?xf32>
  return %transpose : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b, %c = transform.structured.fuse %generic [10, 20]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
// CHECK: func.func @gemm_transpose_fusion(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[INIT0:.+]] = tensor.empty(%[[D0]], %[[D1]])
// CHECK-DAG: %[[INIT1:.+]] = tensor.empty(%[[D1]], %[[D0]])
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT1]])
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV0]], %[[IV1]]]
// CHECK: %[[FILL_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT0_TILE]] :
// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME: outs(%[[FILL_TILE]] :
// CHECK: %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]]
// CHECK: %[[GENERIC_TILE:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GEMM_TILE]] :
// CHECK-SAME: outs(%[[OUTS_TILE]] :
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
// CHECK: scf.yield %[[INSERT]]

// -----

// Matmul + elementwise fusion with `interchange[1, 0]`: the outer induction
// variable (IV0) offsets dimension 1 of the result and the inner one (IV1)
// offsets dimension 0, for both the fused producers and the consumer.
func.func @interchange_matmul_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %cst = arith.constant 0.0 : f32
  %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %3 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%2 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
    ^bb0(%b0 : f32, %b1 : f32):
      %4 = arith.addf %b0, %b0 : f32
      linalg.yield %4 : f32
  } -> tensor<?x?xf32>
  return %3 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b, %c = transform.structured.fuse %generic [10, 20] interchange[1, 0]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
// CHECK: func.func @interchange_matmul_fusion(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
// CHECK: %[[INIT:.+]] = tensor.empty
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]])
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0]
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]]
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV1]], %[[IV0]]]
// CHECK: %[[FILL_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT_TILE]] :
// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME: outs(%[[FILL_TILE]] :
// CHECK: %[[INIT_TILE_2:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]]
// CHECK: %[[GENERIC_TILE:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GEMM_TILE]] :
// CHECK-SAME: outs(%[[INIT_TILE_2]] :
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
// CHECK: scf.yield %[[INSERT]]

// -----

// The generic reads the matmul result twice through identical indexing maps;
// a single fused matmul tile feeds both of its input operands.
func.func @matmul_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
    %arg2: tensor<?x?xf32>) -> tensor<?x?xf32>{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = tensor.dim %arg2, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg2, %c1 : tensor<?x?xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %3 = tensor.dim %2, %c0 : tensor<?x?xf32>
  %4 = tensor.dim %2, %c1 : tensor<?x?xf32>
  %5 = tensor.empty(%3, %4) : tensor<?x?xf32>
  %6 = linalg.generic
      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                        affine_map<(d0, d1) -> (d0, d1)>,
                        affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
      ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%5 : tensor<?x?xf32>) {
    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
      %7 = arith.addf %arg3, %arg4 : f32
      linalg.yield %7 : f32
  } -> tensor<?x?xf32>
  return %6 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b, %c = transform.structured.fuse %generic [10, 20]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
// CHECK: func @matmul_plus_matmul
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]]
// CHECK-SAME: iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}})
// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]])
// CHECK-DAG: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
// CHECK-DAG: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
// CHECK-DAG: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
// CHECK: %[[MATMUL:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]] :
// CHECK-SAME: outs(%[[ST_ARG2]] :
// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]]
// CHECK: %[[ST_RESULT:.+]] = linalg.generic
// CHECK-SAME: ins(%[[MATMUL]], %[[MATMUL]] :
// CHECK-SAME: outs(%[[ST_ARG6]] :
// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]]
// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]]
// CHECK: scf.yield %[[UPDATE]]
// CHECK: scf.yield %[[YIELD]]
// CHECK: return %[[RESULT]]

// -----

// The generic reads the matmul result once directly and once transposed, so
// two matmul tiles are fused: one at offsets [IV0, IV1] and one at the
// transposed offsets [IV1, IV0].
func.func @matmul_plus_transpose_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
    %arg2: tensor<?x?xf32>) -> tensor<?x?xf32>{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = tensor.dim %arg2, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg2, %c1 : tensor<?x?xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %3 = tensor.dim %2, %c0 : tensor<?x?xf32>
  %4 = tensor.dim %2, %c1 : tensor<?x?xf32>
  %5 = tensor.empty(%3, %4) : tensor<?x?xf32>
  %6 = linalg.generic
      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                        affine_map<(d0, d1) -> (d1, d0)>,
                        affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
      ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%5 : tensor<?x?xf32>) {
    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
      %7 = arith.addf %arg3, %arg4 : f32
      linalg.yield %7 : f32
  } -> tensor<?x?xf32>
  return %6 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b, %c = transform.structured.fuse %generic [10, 20]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
// CHECK: func @matmul_plus_transpose_matmul
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]]
// CHECK-SAME: iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}})
// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]])
// CHECK-DAG: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
// CHECK-DAG: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
// CHECK-DAG: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
// CHECK: %[[LHS:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]]
// CHECK-SAME: : tensor<?x?xf32>, tensor<?x?xf32>)
// CHECK-SAME: outs(%[[ST_ARG2]] : tensor<?x?xf32>)
// CHECK-DAG: %[[STR_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0]
// CHECK-DAG: %[[STR_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]]
// CHECK-DAG: %[[STR_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]], %[[IV0]]]
// CHECK: %[[RHS:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[STR_ARG0]], %[[STR_ARG1]] :
// CHECK-SAME: outs(%[[STR_ARG2]] :
// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]]
// CHECK: %[[ST_RESULT:.+]] = linalg.generic
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-SAME: outs(%[[ST_ARG6]] :
// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]]
// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]]
// CHECK: scf.yield %[[UPDATE]]
// CHECK: scf.yield %[[YIELD]]
// CHECK: return %[[RESULT]]

// -----

// Chain of three matmuls; tiling the last one along M (size [10]) fuses the
// other two into the same loop. The untiled producers are still present before
// the loop, where their results feed the tensor.dim ops used for tile sizes.
func.func @matmul_sequence_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
    %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>, %arg4: tensor<?x?xf32>,
    %arg5: tensor<?x?xf32>, %arg6: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N0] * [N0, N1]
  %1 = linalg.matmul ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N1] * [N1, N2]
  %2 = linalg.matmul ins(%1, %arg5 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N2] * [N2, N3]
  return %2 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %matmuls = transform.structured.match ops{["linalg.matmul"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %mm1, %mm2, %mm3 = transform.split_handle %matmuls
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    %a, %b = transform.structured.fuse %mm3 [10]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
// CHECK: func @matmul_sequence_fusion(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[ORIG_GEMM1:.+]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] :
// CHECK-DAG: %[[ORIG_GEMM2:.+]] = linalg.matmul ins(%[[ORIG_GEMM1]], %[[ARG3]] :
// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C0]]
// CHECK-DAG: %[[N2:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C1]]
// CHECK-DAG: %[[N3:.+]] = tensor.dim %[[ARG5]], %[[C1]]
// CHECK: %[[R0:.+]] = scf.for %[[IV:[a-zA-Z0-9_]+]] =
// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
// CHECK-DAG: %[[N1:.+]] = tensor.dim %[[ORIG_GEMM1]], %[[C1]]
// CHECK-DAG: %[[N0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[TILE_M:.+]] = affine.min #[[MAP]](%[[IV]])[%[[M]]]
// CHECK-DAG: %[[SLICE_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[TILE_M]], %[[N0]]]
// CHECK-DAG: %[[SLICE_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
// CHECK-DAG: %[[SLICE_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV]], 0] [%[[TILE_M]], %[[N1]]]
// CHECK-DAG: %[[TILE_GEMM1:.+]] = linalg.matmul ins(%[[SLICE_ARG0]], %[[SLICE_ARG1]] :
// CHECK-SAME: outs(%[[SLICE_ARG2]] :
// CHECK-DAG: %[[SLICE_ARG3:.+]] = tensor.extract_slice %[[ARG3]][0, 0] [%[[N1]], %[[N2]]]
// CHECK-DAG: %[[SLICE_ARG4:.+]] = tensor.extract_slice %[[ARG4]][%[[IV]], 0] [%[[TILE_M]], %[[N2]]]
// CHECK-DAG: %[[TILE_GEMM2:.+]] = linalg.matmul ins(%[[TILE_GEMM1]], %[[SLICE_ARG3]] :
// CHECK-SAME: outs(%[[SLICE_ARG4]] :
// CHECK-DAG: %[[SLICE_ARG5:.+]] = tensor.extract_slice %[[ARG5]][0, 0] [%[[N2]], %[[N3]]]
// CHECK-DAG: %[[SLICE_ARG6:.+]] = tensor.extract_slice %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]]
// CHECK-DAG: %[[TILE_GEMM3:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[TILE_GEMM2]], %[[SLICE_ARG5]] :
// CHECK-SAME: outs(%[[SLICE_ARG6]] :
// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[TILE_GEMM3]] into %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]]
// CHECK: scf.yield %[[UPDATE]]

// -----

// Softmax-style sequence: a max reduction, an exp/sum reduction with two
// results, and an elementwise divide. Tiling the final divide along its
// parallel dimension (size [10]) fuses both reductions and their fills into
// the loop; both fused fills write into the same slice of the 30-element init.
func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant 0xFF800000 : f32
  %0 = tensor.empty() : tensor<30xf32>
  %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
  %2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%arg0 : tensor<30x3xf32>) outs(%1 : tensor<30xf32>) {
    ^bb0(%arg1: f32, %arg2: f32):
      %8 = arith.maximumf %arg2, %arg1 : f32
      linalg.yield %8 : f32
  } -> tensor<30xf32>
  %3 = tensor.empty() : tensor<30x3xf32>
  %4 = linalg.fill ins(%cst : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
  %5:2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%arg0, %2 : tensor<30x3xf32>, tensor<30xf32>) outs(%4, %3 : tensor<30xf32>, tensor<30x3xf32>) {
    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
      %8 = arith.subf %arg1, %arg2 : f32
      %9 = math.exp %8 : f32
      %10 = arith.addf %arg3, %9 : f32
      linalg.yield %10, %9 : f32, f32
  } -> (tensor<30xf32>, tensor<30x3xf32>)
  %6 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%5#1, %5#0 : tensor<30x3xf32>, tensor<30xf32>) outs(%3 : tensor<30x3xf32>) {
    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
      %8 = arith.divf %arg1, %arg2 : f32
      linalg.yield %8 : f32
  } -> tensor<30x3xf32>
  return %6 : tensor<30x3xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %generic1, %generic2, %generic3 = transform.split_handle %generics
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    %a, %b = transform.structured.fuse %generic3 [10]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
// CHECK: func @reduction_sequence(%[[ARG0:.+]]: tensor<30x3xf32>)
// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<30xf32>
// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<30x3xf32>
// CHECK: %[[RESULT:[a-zA-Z0-9]+]] = scf.for %[[IV:[a-zA-Z0-9]+]]
// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]])
// CHECK-DAG: %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0]
// CHECK-DAG: %[[INIT0_SLICE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]]]
// CHECK: %[[FILL0:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT0_SLICE]] :
// CHECK: %[[GENERIC0:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0_SLICE]] :
// CHECK-SAME: outs(%[[FILL0]] :
// CHECK: %[[FILL1:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT0_SLICE]] :
// CHECK: %[[INIT1_SLICE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0]
// CHECK: %[[GENERIC1:.+]]:2 = linalg.generic
// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[GENERIC0]] :
// CHECK-SAME: outs(%[[FILL1]], %[[INIT1_SLICE]] :
// CHECK: %[[ITERARG0_SLICE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
// CHECK: %[[GENERIC2:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC1]]#1, %[[GENERIC1]]#0 :
// CHECK-SAME: outs(%[[ITERARG0_SLICE]] :
// CHECK: %[[INSERTSLICE:.+]] = tensor.insert_slice %[[GENERIC2]] into %[[ITERARG0]][%[[IV]], 0]
// CHECK: scf.yield %[[INSERTSLICE]]
// CHECK: return %[[RESULT]]

// -----

// Tile a tensor.pad (size [8]) whose source is produced by a linalg.generic.
// The tiled pad is guarded by an scf.if; the generic is fused into its else
// branch, and the padded tile is cast before being yielded.
func.func @pad_producer_fusion(%arg0 : tensor<10xf32>) -> tensor<16xf32> {
  %0 = tensor.empty() : tensor<10xf32>
  %1 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
      ins(%arg0 : tensor<10xf32>) outs(%0 : tensor<10xf32>) {
    ^bb0(%b0 : f32, %b1 : f32):
      %2 = arith.addf %b0, %b0: f32
      linalg.yield %2 : f32
  } -> tensor<10xf32>
  %cst = arith.constant 0.0 : f32
  %2 = tensor.pad %1 low[4] high[2] {
    ^bb0(%arg1 : index):
      tensor.yield %cst : f32
  } : tensor<10xf32> to tensor<16xf32>
  return %2 : tensor<16xf32>
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %pad = transform.structured.match ops{["tensor.pad"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b = transform.structured.fuse %pad [8]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
// CHECK-LABEL: func @pad_producer_fusion
// CHECK-SAME: %[[ARG0:.+]]: tensor<10xf32>
// CHECK: %[[FOR_RESULT:.+]] = scf.for
// CHECK: %[[IF_RESULT:.+]] = scf.if
// CHECK: else
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[SLICE]] :
// CHECK: %[[PAD:.+]] = tensor.pad %[[GENERIC]]
// CHECK: %[[CAST:.+]] = tensor.cast %[[PAD]]
// CHECK: scf.yield %[[CAST]]
// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[IF_RESULT]]
// CHECK: scf.yield %[[INSERT_SLICE]]
// CHECK: return %[[FOR_RESULT]]

// -----

// Tile a generic that consumes a tensor.unpack with sizes [0, 1, 0]; only
// dimension 1 is tiled, so a single scf.for is produced. The unpack's inner
// tile along that dimension (8) does not match the tile size (1), so the fused
// unpack result is sliced again before it feeds the generic.
func.func @imperfect_unpack_producer_fusion(%source: tensor<1x1x288x8x4xf32>, %dest: tensor<1x2x1152xf32>) -> tensor<1x2x1152xf32> {
  %0 = tensor.unpack %source
      outer_dims_perm = [0, 1, 2]
      inner_dims_pos = [1, 2]
      inner_tiles = [8, 4] into %dest
      : tensor<1x1x288x8x4xf32> -> tensor<1x2x1152xf32>
  %1 = tensor.empty() : tensor<1x2x1152xf32>
  %cst = arith.constant 1.0 : f32
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
                       iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%0 : tensor<1x2x1152xf32>)
      outs(%1 : tensor<1x2x1152xf32>) {
    ^bb0(%in: f32, %out: f32):
      %7 = arith.addf %in, %cst : f32
      linalg.yield %7 : f32
  } -> tensor<1x2x1152xf32>
  return %2 : tensor<1x2x1152xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b = transform.structured.fuse %generic [0, 1, 0]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// CHECK-LABEL: func @imperfect_unpack_producer_fusion
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x288x8x4xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x2x1152xf32>
// CHECK: %[[FOR_RESULT:.+]] = scf.for{{.*}}iter_args(%[[ITER_ARG:.+]] = {{.*}})
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[SLICE]]
// CHECK-DAG: %[[UNPACK_SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[UNPACK_SLICE]]
// CHECK-SAME: outs(%[[INIT_SLICE]]
// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER_ARG]]
// CHECK: scf.yield %[[INSERT_SLICE]]
// CHECK: return %[[FOR_RESULT]]