// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-simplify-pack-unpack-patterns" %s | FileCheck %s

// Tests for the pack/unpack simplification patterns. A tensor.pack that only
// splits dimensions is expected to become a tensor.expand_shape, and a
// tensor.unpack that only merges dimensions a tensor.collapse_shape. Cases that
// are not pure reshapes (padding, partial slices, dynamic shapes, reordering of
// non-unit outer dimensions, reordered or non-innermost tiles) must be left
// untouched.

// CHECK-LABEL: func.func @single_dim_packing(
// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [8, 32] : tensor<256xf32> into tensor<8x32xf32>
// CHECK: return %[[EXPANDED]] : tensor<8x32xf32>
func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
  %empty = tensor.empty() : tensor<8x32xf32>
  %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256xf32> -> tensor<8x32xf32>
  return %0 : tensor<8x32xf32>
}

// -----

// CHECK-LABEL: func.func @single_dim_packing_with_padding(
// CHECK-SAME: %[[ARG0:.+]]: tensor<255xf32>)
// CHECK-NOT: tensor.expand_shape
// CHECK: tensor.pack
func.func @single_dim_packing_with_padding(%arg0: tensor<255xf32>) -> tensor<8x32xf32> {
  %empty = tensor.empty() : tensor<8x32xf32>
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.pack %arg0 padding_value(%cst : f32) inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<255xf32> -> tensor<8x32xf32>
  return %0 : tensor<8x32xf32>
}

// -----

// CHECK-LABEL: func.func @single_last_inner_dim_packing(
// CHECK-SAME: %[[ARG0:.+]]: tensor<5x256xf32>)
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] output_shape [5, 8, 32] : tensor<5x256xf32> into tensor<5x8x32xf32>
// CHECK: return %[[EXPANDED]] : tensor<5x8x32xf32>
func.func @single_last_inner_dim_packing(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
  %empty = tensor.empty() : tensor<5x8x32xf32>
  %0 = tensor.pack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x256xf32> -> tensor<5x8x32xf32>
  return %0 : tensor<5x8x32xf32>
}

// -----

// CHECK-LABEL: func.func @pack_1d_with_outer_dims_perm(
// CHECK-SAME: %[[ARG0:.+]]: tensor<64xf32>)
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [2, 32] : tensor<64xf32> into tensor<2x32xf32>
// CHECK: return %[[EXPANDED]] : tensor<2x32xf32>
func.func @pack_1d_with_outer_dims_perm(%arg0: tensor<64xf32>) -> tensor<2x32xf32> {
  %empty = tensor.empty() : tensor<2x32xf32>
  %pack = tensor.pack %arg0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<64xf32> -> tensor<2x32xf32>
  return %pack : tensor<2x32xf32>
}

// -----

// CHECK-LABEL: func.func @single_last_inner_dim_packing_with_identity_outer_dims_perm(
// CHECK-SAME: %[[ARG0:.+]]: tensor<5x256xf32>)
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] output_shape [5, 8, 32] : tensor<5x256xf32> into tensor<5x8x32xf32>
// CHECK: return %[[EXPANDED]] : tensor<5x8x32xf32>
func.func @single_last_inner_dim_packing_with_identity_outer_dims_perm(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
  %empty = tensor.empty() : tensor<5x8x32xf32>
  %0 = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x256xf32> -> tensor<5x8x32xf32>
  return %0 : tensor<5x8x32xf32>
}

// -----

// CHECK-LABEL: func.func @packing_with_outer_dims_perm(
// CHECK-NOT: tensor.expand_shape
// CHECK: tensor.pack
func.func @packing_with_outer_dims_perm(%arg0: tensor<5x256xf32>) -> tensor<8x5x32xf32> {
  %empty = tensor.empty() : tensor<8x5x32xf32>
  %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x256xf32> -> tensor<8x5x32xf32>
  return %0 : tensor<8x5x32xf32>
}

// -----

// CHECK-LABEL: func.func @single_first_inner_dim_packing(
// CHECK-NOT: tensor.expand_shape
// CHECK: tensor.pack
func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x5x32xf32> {
  %empty = tensor.empty() : tensor<8x5x32xf32>
  %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256x5xf32> -> tensor<8x5x32xf32>
  return %0 : tensor<8x5x32xf32>
}

// -----

// CHECK-LABEL: func.func @pack_1x32_to_1x32x1x1
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]] output_shape [1, 32, 1, 1]
// CHECK: return %[[EXPANDED]]
func.func @pack_1x32_to_1x32x1x1(%arg0 : tensor<1x32xf32>) -> tensor<1x32x1x1xf32> {
  %empty = tensor.empty() : tensor<1x32x1x1xf32>
  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %empty
    : tensor<1x32xf32> -> tensor<1x32x1x1xf32>
  return %pack : tensor<1x32x1x1xf32>
}

// -----

// CHECK-LABEL: func.func @pack_1x32_to_1x16x1x2
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]] output_shape [1, 16, 1, 2]
// CHECK: return %[[EXPANDED]]
func.func @pack_1x32_to_1x16x1x2(%arg0 : tensor<1x32xf32>) -> tensor<1x16x1x2xf32> {
  %empty = tensor.empty() : tensor<1x16x1x2xf32>
  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [1, 2] into %empty
    : tensor<1x32xf32> -> tensor<1x16x1x2xf32>
  return %pack : tensor<1x16x1x2xf32>
}

// -----

// CHECK-LABEL: func.func @pack_32x1_to_1x16x2x1
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 16, 2, 1]
// CHECK: return %[[EXPANDED]]
func.func @pack_32x1_to_1x16x2x1(%arg0 : tensor<32x1xf32>) -> tensor<1x16x2x1xf32> {
  %empty = tensor.empty() : tensor<1x16x2x1xf32>
  %pack = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 1] into %empty
    : tensor<32x1xf32> -> tensor<1x16x2x1xf32>
  return %pack : tensor<1x16x2x1xf32>
}

// -----

// CHECK-LABEL: func.func @pack_32x1_to_16x1x1x2
// CHECK-NOT: tensor.expand_shape
// CHECK: tensor.pack
func.func @pack_32x1_to_16x1x1x2(%arg0 : tensor<32x1xf32>) -> tensor<16x1x1x2xf32> {
  %empty = tensor.empty() : tensor<16x1x1x2xf32>
  %pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [1, 2] into %empty
    : tensor<32x1xf32> -> tensor<16x1x1x2xf32>
  return %pack : tensor<16x1x1x2xf32>
}

// -----

// CHECK-LABEL: func.func @unpack_1d_to_collapse
// CHECK-SAME: %[[ARG0:.+]]: tensor<8x32xf32>)
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<8x32xf32> into tensor<256xf32>
// CHECK: return %[[COLLAPSED]]
func.func @unpack_1d_to_collapse(%arg0: tensor<8x32xf32>) -> tensor<256xf32> {
  %empty = tensor.empty() : tensor<256xf32>
  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<256xf32>
  return %0 : tensor<256xf32>
}

// -----

// CHECK-LABEL: func.func @unpack_to_partial_slice
// CHECK-NOT: tensor.collapse_shape
// CHECK: tensor.unpack
func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
  %empty = tensor.empty() : tensor<255xf32>
  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<255xf32>
  return %0 : tensor<255xf32>
}

// -----

// CHECK-LABEL: func.func @unpack_dynamic
// CHECK-NOT: tensor.collapse_shape
// CHECK: tensor.unpack
func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
  %c32 = arith.constant 32 : index
  %c0 = arith.constant 0 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x32xf32>
  %size = arith.muli %d0, %c32 : index
  %empty = tensor.empty(%size) : tensor<?xf32>
  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<?x32xf32> -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func.func @single_last_inner_dim_unpacking(
// CHECK-SAME: %[[ARG0:.+]]: tensor<5x8x32xf32>)
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x8x32xf32> into tensor<5x256xf32>
// CHECK: return %[[COLLAPSED]] : tensor<5x256xf32>
func.func @single_last_inner_dim_unpacking(%arg0: tensor<5x8x32xf32>) -> tensor<5x256xf32> {
  %empty = tensor.empty() : tensor<5x256xf32>
  %0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x8x32xf32> -> tensor<5x256xf32>
  return %0 : tensor<5x256xf32>
}

// -----

// CHECK-LABEL: func.func @single_last_inner_dim_unpacking_with_identity_outer_dims_perm(
// CHECK-SAME: %[[ARG0:.+]]: tensor<5x8x32xf32>)
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x8x32xf32> into tensor<5x256xf32>
// CHECK: return %[[COLLAPSED]] : tensor<5x256xf32>
func.func @single_last_inner_dim_unpacking_with_identity_outer_dims_perm(%arg0: tensor<5x8x32xf32>) -> tensor<5x256xf32> {
  %empty = tensor.empty() : tensor<5x256xf32>
  %0 = tensor.unpack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x8x32xf32> -> tensor<5x256xf32>
  return %0 : tensor<5x256xf32>
}

// -----

// CHECK-LABEL: func.func @unpacking_with_outer_dims_perm(
// CHECK-NOT: tensor.collapse_shape
// CHECK: tensor.unpack
func.func @unpacking_with_outer_dims_perm(%arg0: tensor<8x5x32xf32>) -> tensor<5x256xf32> {
  %empty = tensor.empty() : tensor<5x256xf32>
  %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<5x256xf32>
  return %0 : tensor<5x256xf32>
}

// -----

// CHECK-LABEL: func.func @single_first_inner_dim_unpacking(
// CHECK-NOT: tensor.collapse_shape
// CHECK: tensor.unpack
func.func @single_first_inner_dim_unpacking(%arg0: tensor<8x5x32xf32>) -> tensor<256x5xf32> {
  %empty = tensor.empty() : tensor<256x5xf32>
  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<256x5xf32>
  return %0 : tensor<256x5xf32>
}

// -----

// CHECK-LABEL: func.func @unpack_1x32x1x1_to_1x32
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
// CHECK: return %[[COLLAPSED]]
func.func @unpack_1x32x1x1_to_1x32(%arg0 : tensor<1x32x1x1xf32>) -> tensor<1x32xf32> {
  %empty = tensor.empty() : tensor<1x32xf32>
  %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %empty
    : tensor<1x32x1x1xf32> -> tensor<1x32xf32>
  return %unpack : tensor<1x32xf32>
}

// -----

// CHECK-LABEL: func.func @unpack_1x2x1x16_to_1x32
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
// CHECK: return %[[COLLAPSED]]
func.func @unpack_1x2x1x16_to_1x32(%arg0 : tensor<1x2x1x16xf32>) -> tensor<1x32xf32> {
  %empty = tensor.empty() : tensor<1x32xf32>
  %unpack = tensor.unpack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [1, 16] into %empty
    : tensor<1x2x1x16xf32> -> tensor<1x32xf32>
  return %unpack : tensor<1x32xf32>
}

// -----

// CHECK-LABEL: func.func @unpack_1x16x2x1_to_32x1
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
// CHECK: return %[[COLLAPSED]]
func.func @unpack_1x16x2x1_to_32x1(%arg0 : tensor<1x16x2x1xf32>) -> tensor<32x1xf32> {
  %empty = tensor.empty() : tensor<32x1xf32>
  %unpack = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 1] into %empty
    : tensor<1x16x2x1xf32> -> tensor<32x1xf32>
  return %unpack : tensor<32x1xf32>
}

// -----

// CHECK-LABEL: func.func @unpack_16x1x1x2_to_32x1
// CHECK-NOT: tensor.collapse_shape
// CHECK: tensor.unpack
func.func @unpack_16x1x1x2_to_32x1(%arg0 : tensor<16x1x1x2xf32>) -> tensor<32x1xf32> {
  %empty = tensor.empty() : tensor<32x1xf32>
  %unpack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [1, 2] into %empty
    : tensor<16x1x1x2xf32> -> tensor<32x1xf32>
  return %unpack : tensor<32x1xf32>
}

// -----

// CHECK-LABEL: func.func @pad_like_pack(
// CHECK-SAME: %[[ARG0:.+]]: tensor<32x64xf32>)
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 1, 32, 64] : tensor<32x64xf32> into tensor<1x1x32x64xf32>
// CHECK: return %[[EXPANDED]] : tensor<1x1x32x64xf32>
func.func @pad_like_pack(%arg0: tensor<32x64xf32>) -> tensor<1x1x32x64xf32> {
  %empty = tensor.empty() : tensor<1x1x32x64xf32>
  %0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<32x64xf32> -> tensor<1x1x32x64xf32>
  return %0 : tensor<1x1x32x64xf32>
}

// -----

// CHECK-LABEL: func.func @pad_like_pack_with_outer_dims_perm(
// CHECK-SAME: %[[ARG0:.+]]: tensor<32x64xf32>)
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 1, 32, 64] : tensor<32x64xf32> into tensor<1x1x32x64xf32>
// CHECK: return %[[EXPANDED]] : tensor<1x1x32x64xf32>
func.func @pad_like_pack_with_outer_dims_perm(%arg0: tensor<32x64xf32>) -> tensor<1x1x32x64xf32> {
  %empty = tensor.empty() : tensor<1x1x32x64xf32>
  %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<32x64xf32> -> tensor<1x1x32x64xf32>
  return %0 : tensor<1x1x32x64xf32>
}

// -----

// CHECK-LABEL: func.func @inner_pad_like_pack(
// CHECK-SAME: %[[ARG0:.+]]: tensor<32x64xf32>)
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] output_shape [32, 1, 64] : tensor<32x64xf32> into tensor<32x1x64xf32>
// CHECK: return %[[EXPANDED]] : tensor<32x1x64xf32>
func.func @inner_pad_like_pack(%arg0: tensor<32x64xf32>) -> tensor<32x1x64xf32> {
  %empty = tensor.empty() : tensor<32x1x64xf32>
  %0 = tensor.pack %arg0 inner_dims_pos = [1] inner_tiles = [64] into %empty : tensor<32x64xf32> -> tensor<32x1x64xf32>
  return %0 : tensor<32x1x64xf32>
}

// -----

// Do not simplify a pack whose inner_dims_pos reorders the tiled dimensions
// ([1, 0]): the tiles end up transposed relative to the source, so this is not
// a pure reshape.
// CHECK-LABEL: func.func @pad_and_inner_dim_shuffle_pack(
// CHECK-SAME: %[[ARG0:.+]]: tensor<32x64xf32>)
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x64x32xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1, 0] inner_tiles = [64, 32] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<1x1x64x32xf32>
// CHECK: return %[[PACK]] : tensor<1x1x64x32xf32>
func.func @pad_and_inner_dim_shuffle_pack(%arg0: tensor<32x64xf32>) -> tensor<1x1x64x32xf32> {
  %empty = tensor.empty() : tensor<1x1x64x32xf32>
  %0 = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [64, 32] into %empty : tensor<32x64xf32> -> tensor<1x1x64x32xf32>
  return %0 : tensor<1x1x64x32xf32>
}

// -----

// Do not simplify a pack that tiles a non-innermost dimension: appending the
// 64-element tile of dim 1 after the untiled trailing dim (16) moves data, so
// this is not a pure reshape.
// CHECK-LABEL: func.func @pad_like_pack_with_transpose(
// CHECK-SAME: %[[ARG0:.+]]: tensor<32x64x16xf32>)
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x1x16x64xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [64] into %[[EMPTY]] : tensor<32x64x16xf32> -> tensor<32x1x16x64xf32>
// CHECK: return %[[PACK]] : tensor<32x1x16x64xf32>
func.func @pad_like_pack_with_transpose(%arg0: tensor<32x64x16xf32>) -> tensor<32x1x16x64xf32> {
  %empty = tensor.empty() : tensor<32x1x16x64xf32>
  %0 = tensor.pack %arg0 inner_dims_pos = [1] inner_tiles = [64] into %empty : tensor<32x64x16xf32> -> tensor<32x1x16x64xf32>
  return %0 : tensor<32x1x16x64xf32>
}

// -----

// CHECK-LABEL: func.func @unpad_like_unpack(
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x32x64xf32>)
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x32x64xf32> into tensor<32x64xf32>
// CHECK: return %[[COLLAPSED]] : tensor<32x64xf32>
func.func @unpad_like_unpack(%arg0: tensor<1x1x32x64xf32>) -> tensor<32x64xf32> {
  %empty = tensor.empty() : tensor<32x64xf32>
  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<1x1x32x64xf32> -> tensor<32x64xf32>
  return %0 : tensor<32x64xf32>
}

// -----

// CHECK-LABEL: func.func @unpad_like_unpack_with_outer_dims_perm(
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x32x64xf32>)
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x32x64xf32> into tensor<32x64xf32>
// CHECK: return %[[COLLAPSED]] : tensor<32x64xf32>
func.func @unpad_like_unpack_with_outer_dims_perm(%arg0: tensor<1x1x32x64xf32>) -> tensor<32x64xf32> {
  %empty = tensor.empty() : tensor<32x64xf32>
  %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<1x1x32x64xf32> -> tensor<32x64xf32>
  return %0 : tensor<32x64xf32>
}

// -----

// CHECK-LABEL: func.func @inner_unpad_like_unpack(
// CHECK-SAME: %[[ARG0:.+]]: tensor<32x1x64xf32>)
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<32x1x64xf32> into tensor<32x64xf32>
// CHECK: return %[[COLLAPSED]] : tensor<32x64xf32>
func.func @inner_unpad_like_unpack(%arg0: tensor<32x1x64xf32>) -> tensor<32x64xf32> {
  %empty = tensor.empty() : tensor<32x64xf32>
  %0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [64] into %empty : tensor<32x1x64xf32> -> tensor<32x64xf32>
  return %0 : tensor<32x64xf32>
}

// -----

// Do not simplify an unpack whose inner_dims_pos reorders the tiled dimensions
// ([1, 0]): the tiles are transposed relative to the destination, so this is
// not a pure reshape.
// CHECK-LABEL: func.func @unpad_and_inner_dim_shuffle_unpack(
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x32x64xf32>)
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64x32xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]] inner_dims_pos = [1, 0] inner_tiles = [32, 64] into %[[EMPTY]] : tensor<1x1x32x64xf32> -> tensor<64x32xf32>
// CHECK: return %[[UNPACK]] : tensor<64x32xf32>
func.func @unpad_and_inner_dim_shuffle_unpack(%arg0: tensor<1x1x32x64xf32>) -> tensor<64x32xf32> {
  %empty = tensor.empty() : tensor<64x32xf32>
  %0 = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 64] into %empty : tensor<1x1x32x64xf32> -> tensor<64x32xf32>
  return %0 : tensor<64x32xf32>
}

// -----

// Do not simplify an unpack whose tiled destination dimension is not the
// innermost one: moving the trailing 64-element tile back in front of the
// untiled 16-wide dim moves data, so this is not a pure reshape.
// CHECK-LABEL: func.func @unpad_like_unpack_with_transpose(
// CHECK-SAME: %[[SRC:.+]]: tensor<32x1x16x64xf32>)
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<32x64x16xf32>
// CHECK: %[[RESULT:.+]] = tensor.unpack %[[SRC]] inner_dims_pos = [1] inner_tiles = [64] into %[[INIT]] : tensor<32x1x16x64xf32> -> tensor<32x64x16xf32>
// CHECK: return %[[RESULT]] : tensor<32x64x16xf32>
func.func @unpad_like_unpack_with_transpose(%arg0: tensor<32x1x16x64xf32>) -> tensor<32x64x16xf32> {
  // The 64-element tile trails the untiled 16-wide dim in the source, so
  // rebuilding a 32x64x16 tensor reorders elements; the unpack is expected
  // to survive unchanged.
  %init = tensor.empty() : tensor<32x64x16xf32>
  %unpacked = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [64] into %init
    : tensor<32x1x16x64xf32> -> tensor<32x64x16xf32>
  return %unpacked : tensor<32x64x16xf32>
}