// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s

// Each function below wraps one linalg.generic that uses identity indexing
// maps, two parallel iterators, and a body made of a single arith op followed
// by linalg.yield. The transform sequence at the bottom of the file applies
// transform.structured.specialize to every op matched through the LinalgOp
// interface; the checks expect the matching named linalg op and no remaining
// linalg.generic in that function.
//
// NOTE(review): the run line passes the split and diagnostic-verification
// flags, but no separator lines or expected-diagnostic annotations are visible
// in this file -- confirm whether they are still needed.

#map = affine_map<(d0, d1) -> (d0, d1)>

// arith.addf on (%in, %in_0) is expected to become linalg.add.
func.func @specialize_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.addf %in, %in_0 : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_add
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.add ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>

// arith.subf on (%in, %in_0) is expected to become linalg.sub with the inputs
// in their original order.
func.func @specialize_sub(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.subf %in, %in_0 : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_sub
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.sub ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>

// The body subtracts in the opposite order (%in_0 - %in), so the expected
// linalg.sub lists its inputs swapped: (%arg1, %arg0).
// The label below spells out the full function name so it cannot bind to
// @specialize_sub if the tests are ever reordered or renamed.
func.func @specialize_sub_swapped_operands(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.subf %in_0, %in : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_sub_swapped_operands
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.sub ins(%[[ARG1]], %[[ARG0]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>

// arith.mulf on (%in, %in_0) is expected to become linalg.mul.
func.func @specialize_mul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.mulf %in, %in_0 : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_mul
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>

// arith.divf on (%in, %in_0) is expected to become linalg.div.
func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.divf %in, %in_0 : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_div
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>


// Transform script run by --transform-interpreter: collect every op in the
// payload that implements the LinalgOp interface and apply
// transform.structured.specialize to the collected handle.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}