// RUN: mlir-opt -split-input-file \
// RUN:   -transform-preload-library='transform-library-paths=%p/td/decompose-unpack.mlir' \
// RUN:   -transform-interpreter=entry-point=decompose_unpack %s | FileCheck %s

func.func @simple_KCRSsr_to_KCRS(%arg0: tensor<1x1x1x1x8x32xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32> {
  %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x1x1x8x32xf32> -> tensor<1x1x32x8xf32>
  return %0 : tensor<1x1x32x8xf32>
}
// CHECK-LABEL: func.func @simple_KCRSsr_to_KCRS
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
// CHECK: %[[TRANSP:.+]] = linalg.transpose
// CHECK-SAME: ins(%[[TILE]] : tensor<8x32xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<32x8xf32>)
// CHECK-SAME: permutation = [1, 0]
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
// CHECK: return %[[INSERT]]

// -----

func.func @simple_unpack_static_tiles(%input: tensor<1x1x8x2xf32>, %output: tensor<5x1xf32>) -> tensor<5x1xf32> {
  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<1x1x8x2xf32> -> tensor<5x1xf32>
  return %0 : tensor<5x1xf32>
}
// CHECK-LABEL: func.func @simple_unpack_static_tiles
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
// CHECK-NOT: linalg.transpose
// The final extract_slice below produces the destination type directly, so the
// insert_slice op is folded away.
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
// CHECK: return %[[SLICE]]

/// Same as the example above, but with one dynamic tile size.

func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tensor<5x1xf32>, %tile_dim: index) -> tensor<5x1xf32> {
  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%tile_dim, 2] into %output : tensor<1x1x?x2xf32> -> tensor<5x1xf32>
  return %0 : tensor<5x1xf32>
}
// CHECK-LABEL: func.func @simple_unpack_dynamic_tile
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[TILE_DIM:[a-zA-Z0-9]+]]
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, %[[TILE_DIM]], 2] [1, 1, 1, 1]
// CHECK-NOT: linalg.transpose
// The final extract_slice below produces the destination type directly, so the
// insert_slice op is folded away.
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
// CHECK: return %[[SLICE]]

/// Same as the example above, but with one dynamic tile size and a transpose.

func.func @simple_unpack_dynamic_tile_transpose(%src: tensor<1x1x2x?xf32>, %dest: tensor<5x1xf32>, %tile_dim: index) -> tensor<5x1xf32> {
  %0 = tensor.unpack %src inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim] into %dest : tensor<1x1x2x?xf32> -> tensor<5x1xf32>
  return %0 : tensor<5x1xf32>
}
// CHECK-LABEL: func.func @simple_unpack_dynamic_tile_transpose
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[TILE_DIM:[a-zA-Z0-9]+]]
// CHECK: %[[TILE:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 2, %[[TILE_DIM]]] [1, 1, 1, 1] : tensor<1x1x2x?xf32> to tensor<2x?xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[TILE_DIM]]) : tensor<?x2xf32>
// CHECK: %[[TRANSP:.*]] = linalg.transpose
// CHECK-SAME: ins(%[[TILE]] : tensor<2x?xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x2xf32>)
// CHECK-SAME: permutation = [1, 0]
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[TRANSP]][0, 0] [5, 1] [1, 1] : tensor<?x2xf32> to tensor<5x1xf32>
// CHECK: return %[[SLICE]] : tensor<5x1xf32>

/// Same as the example above, but with one scalable tile size.

func.func @simple_unpack_scalable_tile(%input: tensor<1x1x?x2xf32>, %output: tensor<5x1xf32>) -> tensor<5x1xf32> {
  %c8 = arith.constant 8 : index
  %vscale = vector.vscale
  %c8_vscale = arith.muli %vscale, %c8 : index
  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%c8_vscale, 2] into %output : tensor<1x1x?x2xf32> -> tensor<5x1xf32>
  return %0 : tensor<5x1xf32>
}
// CHECK-LABEL: func.func @simple_unpack_scalable_tile
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
// CHECK-DAG: %[[VS:.+]] = vector.vscale
// CHECK: %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, %[[C8_VS]], 2] [1, 1, 1, 1]
// CHECK-NOT: linalg.transpose
// The final extract_slice below produces the destination type directly, so the
// insert_slice op is folded away.
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
// CHECK: return %[[SLICE]]

// -----

func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32>) -> tensor<32x8xf32> {
  %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<1x1x32x8xf32> -> tensor<32x8xf32>
  return %0 : tensor<32x8xf32>
}
// CHECK-LABEL: func.func @simple_CNnc_to_NC
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
// CHECK-NOT: linalg.transpose
// The extracted tile already has the destination type, so the insert_slice op
// is folded away and the tile is returned directly.
// CHECK: return %[[TILE]]

// -----

func.func @simple_NCHWc_to_NCHW(%arg0: tensor<2x1x16x8x32xf32>, %arg1: tensor<2x32x16x8xf32>) -> tensor<2x32x16x8xf32> {
  %0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %arg1 : tensor<2x1x16x8x32xf32> -> tensor<2x32x16x8xf32>
  return %0 : tensor<2x32x16x8xf32>
}
// CHECK-LABEL: func.func @simple_NCHWc_to_NCHW
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [2, 1, 16, 8, 32] [1, 1, 1, 1, 1]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x32x16x8xf32>
// CHECK: %[[TRANSP:.+]] = linalg.transpose
// CHECK-SAME: ins(%[[TILE]] : tensor<2x16x8x32xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x32x16x8xf32>)
// CHECK-SAME: permutation = [0, 3, 1, 2]
// The transposed tile already has the destination type, so the insert_slice op
// is folded away.
// CHECK: return %[[TRANSP]]

// -----

func.func @simple_NHWC_to_NCHW(%arg0: tensor<1x16x8x32xf32>, %arg1: tensor<1x32x16x8xf32>) -> tensor<1x32x16x8xf32> {
  %0 = tensor.unpack %arg0 outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [] inner_tiles = [] into %arg1 : tensor<1x16x8x32xf32> -> tensor<1x32x16x8xf32>
  return %0 : tensor<1x32x16x8xf32>
}
// CHECK-LABEL: func.func @simple_NHWC_to_NCHW
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 16, 8, 32] [1, 1, 1, 1]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x16x8xf32>
// CHECK: %[[TRANSP:.+]] = linalg.transpose
// CHECK-SAME: ins(%[[TILE]] : tensor<16x8x32xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<32x16x8xf32>)
// CHECK-SAME: permutation = [2, 0, 1]
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0] [1, 32, 16, 8] [1, 1, 1, 1]
// CHECK: return %[[INSERT]]

// -----

func.func @unpack_with_dynamic_dims(%arg0: tensor<?x1x1x1x8x32xf32>, %arg1: tensor<?x1x32x8xf32>) -> tensor<?x1x32x8xf32> {
  %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<?x1x1x1x8x32xf32> -> tensor<?x1x32x8xf32>
  return %0 : tensor<?x1x32x8xf32>
}
// CHECK-LABEL: func.func @unpack_with_dynamic_dims
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM0_SRC:.+]] = tensor.dim %[[SRC]], %[[C0]] : tensor<?x1x1x1x8x32xf32>
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0, 0] [%[[DIM0_SRC]], 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0_SRC]]) : tensor<?x32x8xf32>
// CHECK: %[[TRANSP:.+]] = linalg.transpose
// CHECK-SAME: ins(%[[TILE]] : tensor<?x8x32xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x32x8xf32>)
// CHECK-SAME: permutation = [0, 2, 1]
// CHECK: %[[DIM0_DEST:.+]] = tensor.dim %[[DEST]], %[[C0]] : tensor<?x1x32x8xf32>
// CHECK: %[[EXTRACT_SLICE:.+]] = tensor.extract_slice %[[TRANSP]][0, 0, 0] [%[[DIM0_DEST]], 32, 8] [1, 1, 1] : tensor<?x32x8xf32> to tensor<?x32x8xf32>
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[EXTRACT_SLICE]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0] [%[[DIM0_DEST]], 1, 32, 8] [1, 1, 1, 1]
// CHECK: return %[[INSERT]]
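
// For reference, a minimal sketch of what the preloaded transform library
// (%p/td/decompose-unpack.mlir, referenced by the RUN lines above) might
// contain. The entry-point name matches
// -transform-interpreter=entry-point=decompose_unpack; the exact pattern set
// applied by the real library is an assumption here, and the snippet is kept
// commented out so it does not affect this test.
//
//   module attributes {transform.with_named_sequence} {
//     transform.named_sequence @decompose_unpack(
//         %module_op: !transform.any_op {transform.readonly}) {
//       // Apply the pack/unpack decomposition patterns to every function.
//       %funcs = transform.structured.match ops{["func.func"]} in %module_op
//         : (!transform.any_op) -> !transform.op<"func.func">
//       transform.apply_patterns to %funcs {
//         transform.apply_patterns.linalg.decompose_pack_unpack
//       } : !transform.op<"func.func">
//       transform.yield
//     }
//   }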