// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-into-pack-and-unpack %s | FileCheck %s

// Regression tests for the "fold into pack and unpack" test patterns. Each
// split below holds one function plus the FileCheck directives describing the
// IR expected after the patterns have run. Tests whose name contains "nofold"
// or "no_fold" (and a few noted explicitly) expect the IR to stay unfolded.

// An unpack followed by a zero-offset, unit-stride extract_slice is expected to
// become a single unpack into a tensor.empty of the slice sizes.
func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
    %arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
  %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func @fold_unpack_slice(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<?x?x8x4xf32>
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
// CHECK:         %[[INIT:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]]) : tensor<?x?xf32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4]
// CHECK-SAME:      into %[[INIT]]
// CHECK:         return %[[UNPACK]]

// -----

// The slice has a dynamic (possibly non-zero) offset, so the extract_slice is
// expected to remain after the unpack.
func.func @nofold_unpack_slice_non_zero_offset(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
    %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
  %1 = tensor.extract_slice %0[0, %arg4] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func @nofold_unpack_slice_non_zero_offset(
// CHECK:         %[[UNPACK:.+]] = tensor.unpack
// CHECK:         tensor.extract_slice %[[UNPACK]]

// -----

// The slice has a dynamic (possibly non-unit) stride, so the extract_slice is
// expected to remain after the unpack.
func.func @nofold_unpack_slice_non_unit_stride(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
    %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
  %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [%arg4, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func @nofold_unpack_slice_non_unit_stride(
// CHECK:         %[[UNPACK:.+]] = tensor.unpack
// CHECK:         tensor.extract_slice %[[UNPACK]]

// -----

// The slice is rank-reducing (2-D to 0-D), so the extract_slice is expected to
// remain after the unpack.
func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
    %arg2 : index, %arg3 : index) -> tensor<f32> {
  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
  %1 = tensor.extract_slice %0[0, 0] [1, 1] [1, 1] : tensor<?x?xf32> to tensor<f32>
  return %1 : tensor<f32>
}
// CHECK-LABEL: func @nofold_unpack_slice_rank_reduced(
// CHECK:         %[[UNPACK:.+]] = tensor.unpack
// CHECK:         tensor.extract_slice %[[UNPACK]]

// -----

// A high-padding tensor.pad feeding a pack with the same padding value is
// expected to fold into a pack that reads the unpadded source directly.
func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %padded = tensor.pad %src low[0, 0] high[15, 0] {
  ^bb0(%arg0: index, %arg1: index):
    tensor.yield %cst : f32
  } : tensor<16641x16xf32> to tensor<16656x16xf32>
  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
  %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
  return %pack : tensor<2082x1x8x32xf32>
}
// CHECK-LABEL: func.func @pad_pack
// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
// CHECK:         %[[PAD_VAL:.+]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<2082x1x8x32xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[SRC]]
// CHECK-SAME:      padding_value(%[[PAD_VAL]] : f32)
// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DEST]]

// -----

// Same as @pad_pack, but the pad carries the `nofold` marker; both the pad and
// the pack are expected to remain.
func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
  ^bb0(%arg0: index, %arg1: index):
    tensor.yield %cst : f32
  } : tensor<16641x16xf32> to tensor<16656x16xf32>
  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
  %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
  return %pack : tensor<2082x1x8x32xf32>
}
// CHECK-LABEL: func.func @nofold_pad_pack
// CHECK:         tensor.pad
// CHECK:         tensor.pack

// -----

// The pad yields 0.0 while the pack pads with 1.0; both ops are expected to
// remain.
func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
  %c0 = arith.constant 0 : index
  %cst0 = arith.constant 0.000000e+00 : f32
  %cst1 = arith.constant 1.000000e+00 : f32
  %padded = tensor.pad %src low[0, 0] high[15, 0] {
  ^bb0(%arg0: index, %arg1: index):
    tensor.yield %cst0 : f32
  } : tensor<16641x16xf32> to tensor<16656x16xf32>
  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
  %pack = tensor.pack %padded padding_value(%cst1 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
  return %pack : tensor<2082x1x8x32xf32>
}
// CHECK-LABEL: func.func @pad_pack_different_padding_value
// CHECK:         tensor.pad
// CHECK:         tensor.pack

// -----

// A pack followed by a linalg.transpose that only permutes outer dims is
// expected to become a single pack with an updated outer_dims_perm.
func.func @tensor_pack_linalg_transpose_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x57x56x2x32xf32> {
  %0 = tensor.empty() : tensor<56x2x1x57x32xf32>
  %pack = tensor.pack %arg0
    outer_dims_perm = [0, 3, 2, 1]
    inner_dims_pos = [3]
    inner_tiles = [32]
    into %0 : tensor<56x57x1x64xf32> -> tensor<56x2x1x57x32xf32>

  %1 = tensor.empty() : tensor<1x57x56x2x32xf32>
  %transposed = linalg.transpose
    ins(%pack : tensor<56x2x1x57x32xf32>)
    outs(%1 : tensor<1x57x56x2x32xf32>)
    permutation = [2, 3, 0, 1, 4]
  return %transposed : tensor<1x57x56x2x32xf32>
}
// CHECK-LABEL: func @tensor_pack_linalg_transpose_fold(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [2, 1, 0, 3]
// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[INIT]]
// CHECK:         return %[[PACK]]

// -----

// Same as @tensor_pack_linalg_transpose_fold, with a padding value that is
// expected to be carried over to the folded pack.
func.func @tensor_pack_linalg_transpose_fold_with_padding(%arg0: tensor<56x57x1x55xf32>, %padding: f32) -> tensor<1x57x56x2x32xf32> {
  %0 = tensor.empty() : tensor<56x2x1x57x32xf32>
  %pack = tensor.pack %arg0 padding_value(%padding : f32)
    outer_dims_perm = [0, 3, 2, 1]
    inner_dims_pos = [3]
    inner_tiles = [32]
    into %0 : tensor<56x57x1x55xf32> -> tensor<56x2x1x57x32xf32>

  %1 = tensor.empty() : tensor<1x57x56x2x32xf32>
  %transposed = linalg.transpose
    ins(%pack : tensor<56x2x1x57x32xf32>)
    outs(%1 : tensor<1x57x56x2x32xf32>)
    permutation = [2, 3, 0, 1, 4]
  return %transposed : tensor<1x57x56x2x32xf32>
}
// CHECK-LABEL: func @tensor_pack_linalg_transpose_fold_with_padding(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<56x57x1x55xf32>, %[[PADDING:.+]]: f32)
// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] padding_value(%[[PADDING]] : f32)
// CHECK-SAME:      outer_dims_perm = [2, 1, 0, 3]
// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[INIT]]
// CHECK:         return %[[PACK]]

// -----

// The pack has no outer_dims_perm; the folded pack is expected to gain one
// derived from the transpose permutation.
func.func @tensor_pack_linalg_transpose_fold_no_outer_dims_perm(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
  %0 = tensor.empty() : tensor<56x57x1x2x32xf32>
  %pack = tensor.pack %arg0
    inner_dims_pos = [3]
    inner_tiles = [32]
    into %0 : tensor<56x57x1x64xf32> -> tensor<56x57x1x2x32xf32>

  %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
  %transposed = linalg.transpose
    ins(%pack : tensor<56x57x1x2x32xf32>)
    outs(%1 : tensor<1x2x56x57x32xf32>)
    permutation = [2, 3, 0, 1, 4]
  return %transposed : tensor<1x2x56x57x32xf32>
}
// CHECK-LABEL: func @tensor_pack_linalg_transpose_fold_no_outer_dims_perm(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x2x56x57x32xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [2, 3, 0, 1]
// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[INIT]]
// CHECK:         return %[[PACK]]

// -----

// The transpose permutes outer dims among themselves and inner tile dims among
// themselves; the folded pack is expected to reorder inner_dims_pos and
// inner_tiles accordingly.
func.func @tensor_pack_linalg_transpose_fold_tile_dims_transpose(%arg0: tensor<56x72x24x128xf32>) -> tensor<12x56x4x9x32x8x2xf32> {
  %0 = tensor.empty() : tensor<4x9x12x56x8x2x32xf32>
  %pack = tensor.pack %arg0
    outer_dims_perm = [3, 1, 2, 0]
    inner_dims_pos = [1, 2, 3]
    inner_tiles = [8, 2, 32]
    into %0 : tensor<56x72x24x128xf32> -> tensor<4x9x12x56x8x2x32xf32>

  %1 = tensor.empty() : tensor<12x56x4x9x32x8x2xf32>
  %transposed = linalg.transpose
    ins(%pack : tensor<4x9x12x56x8x2x32xf32>)
    outs(%1 : tensor<12x56x4x9x32x8x2xf32>)
    permutation = [2, 3, 0, 1, 6, 4, 5]
  return %transposed : tensor<12x56x4x9x32x8x2xf32>
}
// CHECK-LABEL: func @tensor_pack_linalg_transpose_fold_tile_dims_transpose(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<56x72x24x128xf32>)
// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<12x56x4x9x32x8x2xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [2, 0, 3, 1]
// CHECK-SAME:      inner_dims_pos = [3, 1, 2] inner_tiles = [32, 8, 2]
// CHECK-SAME:      into %[[INIT]]
// CHECK:         return %[[PACK]]

// -----

// The permutation [2, 3, 5, 1, 6, 4, 0] moves an inner tile dim into an outer
// position and an outer dim into an inner position; the pack and the transpose
// are both expected to remain.
func.func @tensor_pack_linalg_transpose_fold_tile_dims_outer_dims_transpose(%arg0: tensor<56x72x24x128xf32>) -> tensor<9x56x2x12x32x8x4xf32> {
  %0 = tensor.empty() : tensor<4x12x9x56x8x2x32xf32>
  %pack = tensor.pack %arg0
    outer_dims_perm = [3, 2, 1, 0]
    inner_dims_pos = [1, 2, 3]
    inner_tiles = [8, 2, 32]
    into %0 : tensor<56x72x24x128xf32> -> tensor<4x12x9x56x8x2x32xf32>

  %1 = tensor.empty() : tensor<9x56x2x12x32x8x4xf32>
  %transposed = linalg.transpose
    ins(%pack : tensor<4x12x9x56x8x2x32xf32>)
    outs(%1 : tensor<9x56x2x12x32x8x4xf32>)
    permutation = [2, 3, 5, 1, 6, 4, 0]
  return %transposed : tensor<9x56x2x12x32x8x4xf32>
}
// CHECK-LABEL: func @tensor_pack_linalg_transpose_fold_tile_dims_outer_dims_transpose(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<56x72x24x128xf32>)
// CHECK:         tensor.pack
// CHECK:         linalg.transpose

// -----

// Source with dynamic outer dims: the folded pack is expected to write into a
// tensor.empty whose dynamic sizes come from tensor.dim on the source.
func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_dims(%arg0: tensor<56x?x?x64xf32>) -> tensor<?x?x56x2x32xf32> {
  %0 = tensor.empty() : tensor<56x2x1x57x32xf32>
  %pack = tensor.pack %arg0
    outer_dims_perm = [0, 3, 2, 1]
    inner_dims_pos = [3]
    inner_tiles = [32]
    into %0 : tensor<56x?x?x64xf32> -> tensor<56x2x1x57x32xf32>

  %1 = tensor.empty() : tensor<1x57x56x2x32xf32>
  %transposed = linalg.transpose
    ins(%pack : tensor<56x2x1x57x32xf32>)
    outs(%1 : tensor<1x57x56x2x32xf32>)
    permutation = [2, 3, 0, 1, 4]

  %return_value = tensor.cast %transposed : tensor<1x57x56x2x32xf32> to tensor<?x?x56x2x32xf32>
  return %return_value : tensor<?x?x56x2x32xf32>
}
// CHECK-LABEL: func @tensor_pack_linalg_transpose_fold_dynamic_outer_dims(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<56x?x?x64xf32>)
// CHECK-DAG:     %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[c2:.+]] = arith.constant 2 : index
// CHECK:         %[[dim:.+]] = tensor.dim %[[ARG0]], %[[c1]] : tensor<56x?x?x64xf32>
// CHECK:         %[[dim_0:.+]] = tensor.dim %[[ARG0]], %[[c2]] : tensor<56x?x?x64xf32>
// CHECK:         %[[INIT:.+]] = tensor.empty(%[[dim_0]], %[[dim]]) : tensor<?x?x56x2x32xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [2, 1, 0, 3]
// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[INIT]]
// CHECK:         return %[[PACK]]

// -----

// Dynamic source dims that are also tiled: the destination sizes are expected
// to be computed with ceildiv affine maps, and the trailing cast is kept.
func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_and_tile_dims(%arg0: tensor<56x?x?x128xf32>) -> tensor<?x?x56x9x32x8x2xf32> {
  %0 = tensor.empty() : tensor<56x9x12x4x8x2x32xf32>
  %pack = tensor.pack %arg0
    inner_dims_pos = [1, 2, 3]
    inner_tiles = [8, 2, 32]
    into %0 : tensor<56x?x?x128xf32> -> tensor<56x9x12x4x8x2x32xf32>

  %1 = tensor.empty() : tensor<12x4x56x9x32x8x2xf32>
  %transposed = linalg.transpose
    ins(%pack : tensor<56x9x12x4x8x2x32xf32>)
    outs(%1 : tensor<12x4x56x9x32x8x2xf32>)
    permutation = [2, 3, 0, 1, 6, 4, 5]

  %return_value = tensor.cast %transposed : tensor<12x4x56x9x32x8x2xf32> to tensor<?x?x56x9x32x8x2xf32>
  return %return_value : tensor<?x?x56x9x32x8x2xf32>
}
// CHECK-DAG:   #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
// CHECK-DAG:   #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
// CHECK-LABEL: func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_and_tile_dims(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<56x?x?x128xf32>)
// CHECK-DAG:     %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[c2:.+]] = arith.constant 2 : index
// CHECK:         %[[dim:.+]] = tensor.dim %[[ARG0]], %[[c1]] : tensor<56x?x?x128xf32>
// CHECK:         %[[dim_0:.+]] = tensor.dim %[[ARG0]], %[[c2]] : tensor<56x?x?x128xf32>
// CHECK:         %[[mapped_dim1:.+]] = affine.apply #[[$MAP0]]()[%[[dim]]]
// CHECK:         %[[mapped_dim2:.+]] = affine.apply #[[$MAP1]]()[%[[dim_0]]]
// CHECK:         %[[INIT:.+]] = tensor.empty(%[[mapped_dim2]], %[[mapped_dim1]]) : tensor<?x4x56x?x32x8x2xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [2, 3, 0, 1] inner_dims_pos = [3, 1, 2] inner_tiles = [32, 8, 2] into %[[INIT]] : tensor<56x?x?x128xf32> -> tensor<?x4x56x?x32x8x2xf32>
// CHECK:         %[[CAST:.+]] = tensor.cast %[[PACK]] : tensor<?x4x56x?x32x8x2xf32> to tensor<?x?x56x9x32x8x2xf32>
// CHECK:         return %[[CAST]] : tensor<?x?x56x9x32x8x2xf32>
// CHECK:       }

// -----

// Fully dynamic shapes and tile sizes: all destination sizes are expected to be
// derived from tensor.dim and a shared ceildiv affine map.
func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_dims_tile_dims_tile_sizes(%arg0: tensor<?x?x?x?xf32>, %pack_dest: tensor<?x?x?x?x?x?x?xf32>, %transpose_dest: tensor<?x?x?x?x?x?x?xf32>, %tile_p : index, %tile_q : index, %tile_r : index) -> tensor<?x?x?x?x?x?x?xf32> {
  %pack = tensor.pack %arg0
    outer_dims_perm = [3, 0, 2, 1]
    inner_dims_pos = [1, 2, 3]
    inner_tiles = [%tile_p, %tile_q, %tile_r]
    into %pack_dest : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>

  %transposed = linalg.transpose
    ins(%pack : tensor<?x?x?x?x?x?x?xf32>)
    outs(%transpose_dest : tensor<?x?x?x?x?x?x?xf32>)
    permutation = [2, 3, 0, 1, 6, 4, 5]

  return %transposed : tensor<?x?x?x?x?x?x?xf32>
}
// CHECK:       #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
// CHECK:       module {
// CHECK-LABEL: func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_dims_tile_dims_tile_sizes(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<?x?x?x?xf32>,
// CHECK-SAME:    %[[PACK_DEST:.+]]: tensor<?x?x?x?x?x?x?xf32>, %[[TRANSPOSE_DEST:.+]]: tensor<?x?x?x?x?x?x?xf32>,
// CHECK-SAME:    %[[ARG1:.+]]: index, %[[ARG2:.+]]: index,
// CHECK-SAME:    %[[ARG3:.+]]: index)
// CHECK-DAG:     %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[c2:.+]] = arith.constant 2 : index
// CHECK-DAG:     %[[c3:.+]] = arith.constant 3 : index
// CHECK:         %[[dim:.+]] = tensor.dim %[[ARG0]], %[[c0]] : tensor<?x?x?x?xf32>
// CHECK:         %[[dim_0:.+]] = tensor.dim %[[ARG0]], %[[c1]] : tensor<?x?x?x?xf32>
// CHECK:         %[[dim_1:.+]] = tensor.dim %[[ARG0]], %[[c2]] : tensor<?x?x?x?xf32>
// CHECK:         %[[dim_2:.+]] = tensor.dim %[[ARG0]], %[[c3]] : tensor<?x?x?x?xf32>
// CHECK:         %[[mapped_dim0:.+]] = affine.apply #[[$MAP]]()[%[[dim_2]], %[[ARG3]]]
// CHECK:         %[[mapped_dim1:.+]] = affine.apply #[[$MAP]]()[%[[dim_0]], %[[ARG1]]]
// CHECK:         %[[mapped_dim2:.+]] = affine.apply #[[$MAP]]()[%[[dim_1]], %[[ARG2]]]
// CHECK:         %[[INIT:.+]] = tensor.empty(%[[mapped_dim2]], %[[mapped_dim1]], %[[mapped_dim0]], %[[dim]], %[[ARG3]], %[[ARG1]], %[[ARG2]]) : tensor<?x?x?x?x?x?x?xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [3, 1, 2] inner_tiles = [%[[ARG3]], %[[ARG1]], %[[ARG2]]] into %[[INIT]] : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
// CHECK:         return %[[PACK]] : tensor<?x?x?x?x?x?x?xf32>
// CHECK:       }

// -----

// A linalg.transpose feeding a pack is expected to become a single pack that
// reads the transpose input with an updated outer_dims_perm.
func.func @linalg_transpose_tensor_pack_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x57x56x2x32xf32> {
  %0 = tensor.empty() : tensor<1x56x57x64xf32>
  %transposed = linalg.transpose
    ins(%arg0 : tensor<56x57x1x64xf32>)
    outs(%0 : tensor<1x56x57x64xf32>)
    permutation = [2, 0, 1, 3]

  %1 = tensor.empty() : tensor<1x57x56x2x32xf32>
  %pack = tensor.pack %transposed
    outer_dims_perm = [0, 2, 1, 3]
    inner_dims_pos = [3]
    inner_tiles = [32]
    into %1 : tensor<1x56x57x64xf32> -> tensor<1x57x56x2x32xf32>
  return %pack : tensor<1x57x56x2x32xf32>
}
// CHECK-LABEL: func @linalg_transpose_tensor_pack_fold(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [2, 1, 0, 3]
// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[INIT]]
// CHECK:         return %[[PACK]]

// -----

// Same as @linalg_transpose_tensor_pack_fold, with a padding value that is
// expected to be carried over to the folded pack.
func.func @linalg_transpose_tensor_pack_fold_with_padding(%arg0: tensor<56x57x1x55xf32>, %padding: f32) -> tensor<1x57x56x2x32xf32> {
  %0 = tensor.empty() : tensor<1x56x57x55xf32>
  %transpose = linalg.transpose
    ins(%arg0 : tensor<56x57x1x55xf32>)
    outs(%0 : tensor<1x56x57x55xf32>)
    permutation = [2, 0, 1, 3]

  %1 = tensor.empty() : tensor<1x57x56x2x32xf32>
  %pack = tensor.pack %transpose padding_value(%padding : f32)
    outer_dims_perm = [0, 2, 1, 3]
    inner_dims_pos = [3]
    inner_tiles = [32]
    into %1 : tensor<1x56x57x55xf32> -> tensor<1x57x56x2x32xf32>
  return %pack : tensor<1x57x56x2x32xf32>
}
// CHECK-LABEL: func @linalg_transpose_tensor_pack_fold_with_padding(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<56x57x1x55xf32>, %[[PADDING:.+]]: f32)
// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] padding_value(%[[PADDING]] : f32)
// CHECK-SAME:      outer_dims_perm = [2, 1, 0, 3]
// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[INIT]]
// CHECK:         return %[[PACK]]

// -----

// The pack has no outer_dims_perm; the folded pack is expected to take the
// transpose permutation as its outer_dims_perm.
func.func @linalg_transpose_tensor_pack_fold_no_outer_dims_perm(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x56x57x2x32xf32> {
  %0 = tensor.empty() : tensor<1x56x57x64xf32>
  %transposed = linalg.transpose
    ins(%arg0 : tensor<56x57x1x64xf32>)
    outs(%0 : tensor<1x56x57x64xf32>)
    permutation = [2, 0, 1, 3]

  %1 = tensor.empty() : tensor<1x56x57x2x32xf32>
  %pack = tensor.pack %transposed
    inner_dims_pos = [3]
    inner_tiles = [32]
    into %1 : tensor<1x56x57x64xf32> -> tensor<1x56x57x2x32xf32>
  return %pack : tensor<1x56x57x2x32xf32>
}
// CHECK-LABEL: func @linalg_transpose_tensor_pack_fold_no_outer_dims_perm(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x56x57x2x32xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [2, 0, 1, 3]
// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[INIT]]
// CHECK:         return %[[PACK]]

// -----

// The transpose changes which source dims the pack tiles; the folded pack is
// expected to remap inner_dims_pos while keeping inner_tiles in order.
func.func @linalg_transpose_tensor_pack_fold_complex_inner_dims_change(%arg0: tensor<25x30x35x40xf32>, %transpose_dest: tensor<35x40x25x30xf32>, %pack_dest: tensor<3x35x5x8x5x10x5xf32>) -> tensor<3x35x5x8x5x10x5xf32> {
  %transposed = linalg.transpose
    ins(%arg0 : tensor<25x30x35x40xf32>)
    outs(%transpose_dest : tensor<35x40x25x30xf32>)
    permutation = [2, 3, 0, 1]

  %pack = tensor.pack %transposed
    outer_dims_perm = [3, 0, 2, 1]
    inner_dims_pos = [1, 3, 2]
    inner_tiles = [5, 10, 5]
    into %pack_dest : tensor<35x40x25x30xf32> -> tensor<3x35x5x8x5x10x5xf32>
  return %pack : tensor<3x35x5x8x5x10x5xf32>
}
// CHECK-LABEL: func.func @linalg_transpose_tensor_pack_fold_complex_inner_dims_change(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<25x30x35x40xf32>,
// CHECK-SAME:    %[[ARG1:.+]]: tensor<35x40x25x30xf32>,
// CHECK-SAME:    %[[ARG2:.+]]: tensor<3x35x5x8x5x10x5xf32>) -> tensor<3x35x5x8x5x10x5xf32> {
// CHECK:         %[[VAL0:.+]] = tensor.empty() : tensor<3x35x5x8x5x10x5xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [1, 2, 0, 3]
// CHECK-SAME:      inner_dims_pos = [3, 1, 0]
// CHECK-SAME:      inner_tiles = [5, 10, 5]
// CHECK-SAME:      into %[[VAL0]]
// CHECK:         return %[[PACK]]

// -----

// Fully dynamic variant of the transpose-then-pack fold: destination sizes are
// expected to come from tensor.dim and a shared ceildiv affine map.
func.func @linalg_transpose_tensor_pack_fold_dynamic_outer_dims_tile_dims_tile_sizes(%arg0: tensor<?x?x?x?xf32>, %transpose_dest: tensor<?x?x?x?xf32>, %pack_dest: tensor<?x?x?x?x?x?x?xf32>, %tile_p : index, %tile_q : index, %tile_r : index) -> tensor<?x?x?x?x?x?x?xf32> {
  %transposed = linalg.transpose
    ins(%arg0 : tensor<?x?x?x?xf32>)
    outs(%transpose_dest : tensor<?x?x?x?xf32>)
    permutation = [2, 3, 0, 1]

  %pack = tensor.pack %transposed
    outer_dims_perm = [3, 0, 2, 1]
    inner_dims_pos = [1, 3, 2]
    inner_tiles = [%tile_p, %tile_q, %tile_r]
    into %pack_dest : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
  return %pack : tensor<?x?x?x?x?x?x?xf32>
}
// CHECK:       #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
// CHECK-LABEL: func.func @linalg_transpose_tensor_pack_fold_dynamic_outer_dims_tile_dims_tile_sizes(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<?x?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?x?xf32>,
// CHECK-SAME:    %[[ARG2:.+]]: tensor<?x?x?x?x?x?x?xf32>, %[[ARG3:.+]]: index, %[[ARG4:.+]]: index, %[[ARG5:.+]]: index) -> tensor<?x?x?x?x?x?x?xf32> {
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG:     %[[C3:.+]] = arith.constant 3 : index
// CHECK:         %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK:         %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK:         %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK:         %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK:         %[[VAL0:.+]] = affine.apply #[[$MAP]]()[%[[DIM2]], %[[ARG3]]]
// CHECK:         %[[VAL1:.+]] = affine.apply #[[$MAP]]()[%[[DIM0]], %[[ARG4]]]
// CHECK:         %[[VAL2:.+]] = affine.apply #[[$MAP]]()[%[[DIM]], %[[ARG5]]]
// CHECK:         %[[VAL3:.+]] = tensor.empty(%[[VAL1]], %[[DIM1]], %[[VAL2]], %[[VAL0]], %[[ARG3]], %[[ARG4]], %[[ARG5]]) : tensor<?x?x?x?x?x?x?xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [1, 2, 0, 3] inner_dims_pos = [3, 1, 0] inner_tiles = [%[[ARG3]], %[[ARG4]], %[[ARG5]]] into %[[VAL3]] : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
// CHECK:         return %[[PACK]] : tensor<?x?x?x?x?x?x?xf32>

// -----

// Transpose-then-pack with two tiled dims, one of them dynamic, and a padding
// value; the fold is expected to produce a single pack of the original source.
func.func @linalg_transpose_tensor_pack_multiple_tiles(%arg0: tensor<?x32x128xbf16>) -> tensor<32x?x64x16x2xbf16> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : bf16
  %dim = tensor.dim %arg0, %c0 : tensor<?x32x128xbf16>

  %0 = tensor.empty(%dim) : tensor<32x128x?xbf16>
  %transposed = linalg.transpose
    ins(%arg0 : tensor<?x32x128xbf16>)
    outs(%0 : tensor<32x128x?xbf16>)
    permutation = [1, 2, 0]

  %2 = tensor.empty(%dim) : tensor<32x?x64x16x2xbf16>
  %pack = tensor.pack %transposed
    padding_value(%cst : bf16)
    outer_dims_perm = [0, 2, 1]
    inner_dims_pos = [2, 1]
    inner_tiles = [16, 2]
    into %2 : tensor<32x128x?xbf16> -> tensor<32x?x64x16x2xbf16>
  return %pack : tensor<32x?x64x16x2xbf16>
}
// CHECK:       #[[$MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
// CHECK-LABEL: func.func @linalg_transpose_tensor_pack_multiple_tiles(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<?x32x128xbf16>) -> tensor<32x?x64x16x2xbf16> {
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[CST:.+]] = arith.constant 0.000000e+00 : bf16
// CHECK:         %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x32x128xbf16>
// CHECK:         %[[VAL0:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
// CHECK:         %[[VAL1:.+]] = tensor.empty(%[[VAL0]]) : tensor<32x?x64x16x2xbf16>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      padding_value(%[[CST]] : bf16)
// CHECK-SAME:      outer_dims_perm = [1, 0, 2]
// CHECK-SAME:      inner_dims_pos = [0, 2]
// CHECK-SAME:      inner_tiles = [16, 2]
// CHECK-SAME:      into %[[VAL1]] : tensor<?x32x128xbf16> -> tensor<32x?x64x16x2xbf16>
// CHECK:         return %[[PACK]] : tensor<32x?x64x16x2xbf16>
// CHECK:       }

// -----

// A linalg.transpose feeding an unpack is expected to become a single unpack of
// the transpose input with permuted outer dims, inner dims and tiles.
func.func @linalg_transpose_tensor_unpack_fold(%arg0: tensor<1x1x4x16xi32>) -> tensor<16x4xi32> {
  %0 = tensor.empty() : tensor<1x1x16x4xi32>
  %transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>)
                outs(%0 : tensor<1x1x16x4xi32>)
                permutation = [1, 0, 3, 2]
  %1 = tensor.empty() : tensor<16x4xi32>
  %unpack = tensor.unpack %transposed
            outer_dims_perm = [0, 1]
            inner_dims_pos = [0, 1]
            inner_tiles = [16, 4] into
            %1 : tensor<1x1x16x4xi32> -> tensor<16x4xi32>
  return %unpack : tensor<16x4xi32>
}
// CHECK-LABEL: func.func @linalg_transpose_tensor_unpack_fold(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<1x1x4x16xi32>) -> tensor<16x4xi32> {
// CHECK:         %[[OUT:.+]] = tensor.empty() : tensor<16x4xi32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [1, 0]
// CHECK-SAME:      inner_dims_pos = [1, 0]
// CHECK-SAME:      inner_tiles = [4, 16]
// CHECK-SAME:      into %[[OUT]] : tensor<1x1x4x16xi32> -> tensor<16x4xi32>
// CHECK:         return %[[UNPACK]] : tensor<16x4xi32>
// CHECK:       }

// -----

// Same as @linalg_transpose_tensor_unpack_fold, but the unpack result (15x3) is
// smaller than the full tiles (16x4).
func.func @linalg_transpose_tensor_unpack_fold_partial_tile(%arg0: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
  %0 = tensor.empty() : tensor<1x1x16x4xi32>
  %transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>)
                outs(%0 : tensor<1x1x16x4xi32>)
                permutation = [1, 0, 3, 2]
  %1 = tensor.empty() : tensor<15x3xi32>
  %unpack = tensor.unpack %transposed
            outer_dims_perm = [0, 1]
            inner_dims_pos = [0, 1]
            inner_tiles = [16, 4] into
            %1 : tensor<1x1x16x4xi32> -> tensor<15x3xi32>
  return %unpack : tensor<15x3xi32>
}
// CHECK-LABEL: func.func @linalg_transpose_tensor_unpack_fold_partial_tile(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
// CHECK:         %[[OUT:.+]] = tensor.empty() : tensor<15x3xi32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [1, 0]
// CHECK-SAME:      inner_dims_pos = [1, 0]
// CHECK-SAME:      inner_tiles = [4, 16]
// CHECK-SAME:      into %[[OUT]] : tensor<1x1x4x16xi32> -> tensor<15x3xi32>
// CHECK:         return %[[UNPACK]] : tensor<15x3xi32>
// CHECK:       }

// -----

// Fully dynamic transpose-then-unpack: the new destination is expected to be a
// tensor.empty sized from tensor.dim on the original unpack destination.
func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(%arg0: tensor<?x?x?x?xf32>, %transpose_dest: tensor<?x?x?x?xf32>, %unpack_dest: tensor<?x?xf32>, %tile_p : index, %tile_q : index) -> tensor<?x?xf32> {
  %transposed = linalg.transpose
    ins(%arg0 : tensor<?x?x?x?xf32>)
    outs(%transpose_dest : tensor<?x?x?x?xf32>)
    permutation = [1, 0, 3, 2]

  %unpack = tensor.unpack %transposed
    outer_dims_perm = [1, 0]
    inner_dims_pos = [0, 1]
    inner_tiles = [%tile_p, %tile_q]
    into %unpack_dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
  return %unpack : tensor<?x?xf32>
}
// CHECK-LABEL: func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<?x?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>,
// CHECK-SAME:    %[[IDX1:.+]]: index, %[[IDX2:.+]]: index) -> tensor<?x?xf32> {
// CHECK-DAG:     %[[CST1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[CST0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[CST0]] : tensor<?x?xf32>
// CHECK-DAG:     %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[CST1]] : tensor<?x?xf32>
// CHECK:         %[[OUT:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?xf32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 1]
// CHECK-SAME:      inner_dims_pos = [1, 0]
// CHECK-SAME:      inner_tiles = [%[[IDX2]], %[[IDX1]]]
// CHECK-SAME:      into %[[OUT]] : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
// CHECK:         return %[[UNPACK]] : tensor<?x?xf32>
// CHECK:       }

// -----

// An unpack feeding a linalg.transpose is expected to become a single unpack
// that writes the transposed shape directly.
func.func @tensor_unpack_linalg_transpose_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<3648x56xf32> {
  %0 = tensor.empty() : tensor<56x3648xf32>
  %unpack = tensor.unpack %arg0
    outer_dims_perm = [0, 1]
    inner_dims_pos = [0, 1]
    inner_tiles = [1, 64]
    into %0 : tensor<56x57x1x64xf32> -> tensor<56x3648xf32>

  %1 = tensor.empty() : tensor<3648x56xf32>
  %transposed = linalg.transpose
    ins(%unpack : tensor<56x3648xf32>)
    outs(%1 : tensor<3648x56xf32>)
    permutation = [1,0]
  return %transposed : tensor<3648x56xf32>
}
// CHECK-LABEL: func.func @tensor_unpack_linalg_transpose_fold(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<56x57x1x64xf32>) -> tensor<3648x56xf32> {
// CHECK:         %[[OUT:.+]] = tensor.empty() : tensor<3648x56xf32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [1, 0]
// CHECK-SAME:      inner_dims_pos = [1, 0]
// CHECK-SAME:      inner_tiles = [1, 64]
// CHECK-SAME:      into %[[OUT]] : tensor<56x57x1x64xf32> -> tensor<3648x56xf32>
// CHECK:         return %[[UNPACK]] : tensor<3648x56xf32>
// CHECK:       }

// -----

// Unpack-then-transpose where the unpacked dim (100) is smaller than the packed
// extent (7 tiles of 16); the fold is still expected to apply.
func.func @tensor_padded_unpack_linalg_transpose_fold(%arg0: tensor<71x7x4x16x16xf32>) -> tensor<100x71x64xf32> {
  %0 = tensor.empty() : tensor<71x100x64xf32>
  %unpack = tensor.unpack %arg0
    inner_dims_pos = [1, 2]
    inner_tiles = [16, 16]
    into %0 : tensor<71x7x4x16x16xf32> -> tensor<71x100x64xf32>

  %1 = tensor.empty() : tensor<100x71x64xf32>
  %transposed = linalg.transpose
    ins(%unpack : tensor<71x100x64xf32>)
    outs(%1 : tensor<100x71x64xf32>)
    permutation = [1, 0, 2]
  return %transposed : tensor<100x71x64xf32>
}
// CHECK-LABEL: func.func @tensor_padded_unpack_linalg_transpose_fold(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<71x7x4x16x16xf32>) -> tensor<100x71x64xf32> {
// CHECK:         %[[OUT:.+]] = tensor.empty() : tensor<100x71x64xf32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [1, 0, 2]
// CHECK-SAME:      inner_dims_pos = [0, 2]
// CHECK-SAME:      inner_tiles = [16, 16]
// CHECK-SAME:      into %[[OUT]] : tensor<71x7x4x16x16xf32> -> tensor<100x71x64xf32>
// CHECK:         return %[[UNPACK]] : tensor<100x71x64xf32>
// CHECK:       }

// -----

// Transpose-then-unpack where the transpose permutation [2, 0, 1, 4, 3] is not
// its own inverse.
func.func @non_involution_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
  %transposed = linalg.transpose ins(%arg0 : tensor<2x3x5x4x16xi32>)
                outs(%0 : tensor<5x2x3x16x4xi32>)
                permutation = [2, 0, 1, 4, 3]
  %1 = tensor.empty() : tensor<5x48x8xi32>
  %unpack = tensor.unpack %transposed
            outer_dims_perm = [0, 2, 1]
            inner_dims_pos = [1, 2]
            inner_tiles = [16, 4] into
            %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
  return %unpack : tensor<5x48x8xi32>
}
// CHECK-LABEL: func.func @non_involution_transpose_unpack_fold(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
// CHECK:         %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [2, 1, 0]
// CHECK-SAME:      inner_dims_pos = [2, 1]
// CHECK-SAME:      inner_tiles = [4, 16]
// CHECK-SAME:      into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
// CHECK:         return %[[UNPACK]] : tensor<5x48x8xi32>
// CHECK:       }

// -----

// Unpack-then-transpose where the transpose permutation [2, 0, 1] is not its
// own inverse.
func.func @unpack_non_involution_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
  %0 = tensor.empty() : tensor<3x56x3648xf32>
  %unpack = tensor.unpack %arg0
    outer_dims_perm = [2, 0, 1]
    inner_dims_pos = [1, 2]
    inner_tiles = [1, 64]
    into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>

  %1 = tensor.empty() : tensor<3648x3x56xf32>
  %transposed = linalg.transpose
    ins(%unpack : tensor<3x56x3648xf32>)
    outs(%1 : tensor<3648x3x56xf32>)
    permutation = [2, 0, 1]
  return %transposed : tensor<3648x3x56xf32>
}
// CHECK-LABEL: func.func @unpack_non_involution_transpose_fold(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
// CHECK:         %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 1, 2]
// CHECK-SAME:      inner_dims_pos = [2, 0]
// CHECK-SAME:      inner_tiles = [1, 64]
// CHECK-SAME:      into %[[OUT]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
// CHECK:         return %[[UNPACK]] : tensor<3648x3x56xf32>
// CHECK:       }

// -----

// The permutation [2, 0, 4, 1, 3] moves source dims across the outer/inner
// boundary of the unpack operand; the transpose and the unpack are both
// expected to remain.
func.func @transpose_unpacked_dims_no_fold(%arg0: tensor<2x16x5x4x3xi32>) -> tensor<5x32x12xi32> {
  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
  %transposed = linalg.transpose ins(%arg0 : tensor<2x16x5x4x3xi32>)
                outs(%0 : tensor<5x2x3x16x4xi32>)
                permutation = [2, 0, 4, 1, 3]
  %1 = tensor.empty() : tensor<5x32x12xi32>
  %unpack = tensor.unpack %transposed
            inner_dims_pos = [1, 2]
            inner_tiles = [16, 4] into
            %1 : tensor<5x2x3x16x4xi32> -> tensor<5x32x12xi32>
  return %unpack : tensor<5x32x12xi32>
}
// CHECK-LABEL: func.func @transpose_unpacked_dims_no_fold(
// CHECK:         linalg.transpose
// CHECK:         tensor.unpack

// -----

// Same shapes as @non_involution_transpose_unpack_fold, with the transpose
// written as a linalg.generic that only yields its input.
#map = affine_map<(d0, d1, d2, d3, d4)->(d1, d2, d0, d4, d3)>
#map1 = affine_map<(d0, d1, d2, d3, d4)->(d0, d1, d2, d3, d4)>
func.func @generic_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
  %transposed = linalg.generic {
                iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
                indexing_maps = [#map, #map1]}
                ins(%arg0 : tensor<2x3x5x4x16xi32>)
                outs(%0 : tensor<5x2x3x16x4xi32>) {
  ^bb0(%in : i32, %out : i32):
    linalg.yield %in : i32
  } -> tensor<5x2x3x16x4xi32>
  %1 = tensor.empty() : tensor<5x48x8xi32>
  %unpack = tensor.unpack %transposed
            outer_dims_perm = [0, 2, 1]
            inner_dims_pos = [1, 2]
            inner_tiles = [16, 4] into
            %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
  return %unpack : tensor<5x48x8xi32>
}
// CHECK-LABEL: func.func @generic_transpose_unpack_fold(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
// CHECK:         %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [2, 1, 0]
// CHECK-SAME:      inner_dims_pos = [2, 1]
// CHECK-SAME:      inner_tiles = [4, 16]
// CHECK-SAME:      into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
// CHECK:         return %[[UNPACK]] : tensor<5x48x8xi32>
// CHECK:       }

// -----

// Same shapes as @unpack_non_involution_transpose_fold, with the transpose
// written as a linalg.generic that only yields its input.
#map = affine_map<(d0, d1, d2)->(d1, d2, d0)>
#map1 = affine_map<(d0, d1, d2)->(d0, d1, d2)>
func.func @unpack_generic_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
  %0 = tensor.empty() : tensor<3x56x3648xf32>
  %unpack = tensor.unpack %arg0
    outer_dims_perm = [2, 0, 1]
    inner_dims_pos = [1, 2]
    inner_tiles = [1, 64]
    into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>

  %1 = tensor.empty() : tensor<3648x3x56xf32>
  %transposed = linalg.generic {
                iterator_types = ["parallel", "parallel", "parallel"],
                indexing_maps = [#map, #map1]}
                ins(%unpack : tensor<3x56x3648xf32>)
                outs(%1 : tensor<3648x3x56xf32>) {
  ^bb0(%in : f32, %out : f32):
    linalg.yield %in : f32
  } -> tensor<3648x3x56xf32>
  return %transposed : tensor<3648x3x56xf32>
}
// CHECK-LABEL: func.func @unpack_generic_transpose_fold(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
// CHECK:         %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 1, 2]
// CHECK-SAME:      inner_dims_pos = [2, 0]
// CHECK-SAME:      inner_tiles = [1, 64]
// CHECK-SAME:      into %[[OUT]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
// CHECK:         return %[[UNPACK]] : tensor<3648x3x56xf32>
// CHECK:       }