// RUN: mlir-opt %s -test-linalg-data-layout-propagation -split-input-file | FileCheck %s

#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @dynamic_elem_pack(%arg0: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
  %3 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<?x?xf32>)
      outs(%2 : tensor<?x?xf32>) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = arith.addf %arg3, %arg3 : f32
      linalg.yield %4 : f32
  } -> tensor<?x?xf32>
  %4 = tensor.pack %3
    inner_dims_pos = [0, 1]
    inner_tiles = [8, 2]
    into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
  return %4 : tensor<?x?x8x2xf32>
}
// CHECK-DAG:   #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
// CHECK-DAG:   #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
// CHECK-DAG:   #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @dynamic_elem_pack
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG:     %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG:     %[[OUTER_D0:.+]] = affine.apply #[[$MAP0]]()[%[[D0]]]
// CHECK-DAG:     %[[OUTER_D1:.+]] = affine.apply #[[$MAP1]]()[%[[D1]]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]], %[[OUTER_D1]]) : tensor<?x?x8x2xf32>
// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [8, 2]
// CHECK-SAME:      into %[[ARG0_EMPTY]]
// CHECK:         %[[ELEM:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP2]], #[[$MAP2]]]
// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME:      ins(%[[PACK_ARG0]]
// CHECK-SAME:      outs(%[[DEST]]
// CHECK:         return %[[ELEM]] : tensor<?x?x8x2xf32>

// -----

#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @elem_pack_transpose_inner_dims(%arg0: tensor<128x256xi32>, %dest: tensor<4x16x16x32xi32>) -> tensor<4x16x16x32xi32>{
  %init = tensor.empty() : tensor<128x256xi32>
  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<128x256xi32>)
      outs(%init : tensor<128x256xi32>) {
    ^bb0(%arg3: i32, %arg4: i32):
      %4 = arith.addi %arg3, %arg3 : i32
      linalg.yield %4 : i32
  } -> tensor<128x256xi32>
  %pack = tensor.pack %elem
    inner_dims_pos = [1, 0]
    inner_tiles = [16, 32]
    into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32>
  return %pack : tensor<4x16x16x32xi32>
}
// CHECK-DAG:   #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @elem_pack_transpose_inner_dims
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [1, 0] inner_tiles = [16, 32]
// CHECK-SAME:      into %[[ARG0_EMPTY]]
// CHECK:         %[[ELEM:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME:      ins(%[[PACK_ARG0]]
// CHECK-SAME:      outs(%[[DEST]]
// CHECK:         return %[[ELEM]] : tensor<4x16x16x32xi32>

// -----

#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x32x16xi32>) -> tensor<16x4x32x16xi32>{
  %init = tensor.empty() : tensor<128x256xi32>
  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<128x256xi32>)
      outs(%init : tensor<128x256xi32>) {
    ^bb0(%arg3: i32, %arg4: i32):
      %4 = arith.addi %arg3, %arg3 : i32
      linalg.yield %4 : i32
  } -> tensor<128x256xi32>
  %pack = tensor.pack %elem
    outer_dims_perm = [1, 0]
    inner_dims_pos = [0, 1]
    inner_tiles = [32, 16]
    into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
  return %pack : tensor<16x4x32x16xi32>
}
// CHECK-DAG:   #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @elem_pack_transpose_outer_dims
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:      into %[[ARG0_EMPTY]] : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
// CHECK:         %[[ELEM:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME:      ins(%[[PACK_ARG0]]
// CHECK-SAME:      outs(%[[DEST]]
// CHECK:         return %[[ELEM]] : tensor<16x4x32x16xi32>

// -----

#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @elem_pack_transpose_inner_and_outer_dims(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x16x32xi32>) -> tensor<16x4x16x32xi32>{
  %init = tensor.empty() : tensor<128x256xi32>
  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<128x256xi32>)
      outs(%init : tensor<128x256xi32>) {
    ^bb0(%arg3: i32, %arg4: i32):
      %4 = arith.addi %arg3, %arg3 : i32
      linalg.yield %4 : i32
  } -> tensor<128x256xi32>
  %pack = tensor.pack %elem
    outer_dims_perm = [1, 0]
    inner_dims_pos = [1, 0]
    inner_tiles = [16, 32]
    into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32>
  return %pack : tensor<16x4x16x32xi32>
}
// CHECK-DAG:   #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @elem_pack_transpose_inner_and_outer_dims
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x16x32xi32>
// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32]
// CHECK-SAME:      into %[[ARG0_EMPTY]]
// CHECK:         %[[ELEM:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME:      ins(%[[PACK_ARG0]]
// CHECK-SAME:      outs(%[[DEST]]
// CHECK:         return %[[ELEM]] : tensor<16x4x16x32xi32>

// -----

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
#map2 = affine_map<(d0, d1) -> (d1)>
func.func @dynamic_broadcast_pack(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %dest: tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
{
  %c0 = arith.constant 0 : index
  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
  %1 = tensor.dim %arg1, %c0 : tensor<?xf32>
  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
  %3 = linalg.generic {indexing_maps = [#map1, #map2, #map0], iterator_types = ["parallel", "parallel"]}
      ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
      outs(%2 : tensor<?x?xf32>) {
    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
      %4 = arith.addf %arg3, %arg4 : f32
      linalg.yield %4 : f32
  } -> tensor<?x?xf32>
  %4 = tensor.pack %3
    inner_dims_pos = [0, 1]
    inner_tiles = [8, 2]
    into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
  return %4 : tensor<?x?x8x2xf32>
}
// CHECK-DAG:   #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
// CHECK-DAG:   #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
// CHECK-DAG:   #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
// CHECK-DAG:   #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
// CHECK-DAG:   #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @dynamic_broadcast_pack
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG:     %[[OUTER_D0:.+]] = affine.apply #[[$MAP0]]()[%[[D0]]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]]) : tensor<?x8xf32>
// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [8]
// CHECK-SAME:      into %[[ARG0_EMPTY]]
// CHECK-DAG:     %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C0]]
// CHECK-DAG:     %[[OUTER_D1:.+]] = affine.apply #[[$MAP1]]()[%[[D1]]]
// CHECK:         %[[ARG1_EMPTY:.+]] = tensor.empty(%[[OUTER_D1]]) : tensor<?x2xf32>
// CHECK:         %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]]
// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [2]
// CHECK-SAME:      into %[[ARG1_EMPTY]]
// CHECK:         %[[ELEM:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP4]]]
// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME:      ins(%[[PACK_ARG0]], %[[PACK_ARG1]]
// CHECK-SAME:      outs(%[[DEST]]
// CHECK:         return %[[ELEM]] : tensor<?x?x8x2xf32>

// -----

#map = affine_map<(d0, d1, d2, d3) -> (d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @elem_pack_transpose_inner_and_outer_dims2(%arg0: tensor<64xf32>, %dest: tensor<1x2x56x57x32xf32>) -> tensor<1x2x56x57x32xf32> {
  %0 = tensor.empty() : tensor<1x56x57x64xf32>
  %1 = linalg.generic {
      indexing_maps = [#map, #map1],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%arg0 : tensor<64xf32>)
      outs(%0 : tensor<1x56x57x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
  } -> tensor<1x56x57x64xf32>
  %2 = tensor.pack %1 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %dest : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32>
  return %2 : tensor<1x2x56x57x32xf32>
}
// CHECK-DAG:   #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)>
// CHECK-DAG:   #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func.func @elem_pack_transpose_inner_and_outer_dims2
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<2x32xf32>
// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG0_EMPTY]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
// CHECK-SAME:      ins(%[[PACKED_ARG0]]
// CHECK-SAME:      outs(%[[DEST]]

// -----

func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<100x200x4x16x16x32xi32>) -> tensor<100x200x4x16x16x32xi32>
{
  %init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
  %transpose = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0)>,
                       affine_map<(d0, d1, d2, d3) -> (d1)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100xi32>, tensor<128xi32>)
      outs(%init_transpose : tensor<100x200x128x256xi32>) {
    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
      %0 = arith.addi %b0, %b1 : i32
      %1 = arith.addi %0, %b2 : i32
      linalg.yield %1 : i32
  } -> tensor<100x200x128x256xi32>
  %4 = tensor.pack %transpose
    inner_dims_pos = [3, 2]
    inner_tiles = [16, 32]
    into %dest : tensor<100x200x128x256xi32> -> tensor<100x200x4x16x16x32xi32>
  return %4 : tensor<100x200x4x16x16x32xi32>
}
// CHECK-DAG:   #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
// CHECK-DAG:   #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>
// CHECK-DAG:   #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d5)>
// CHECK-DAG:   #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3, d4, d5)>
// CHECK-LABEL: func.func @transpose_pack
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<100x4x200x16x16x32xi32>
// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [3, 1] inner_tiles = [16, 32]
// CHECK-SAME:      into %[[ARG0_EMPTY]]
// CHECK:         %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
// CHECK:         %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG2_EMPTY]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
// CHECK-SAME:      ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
// CHECK-SAME:      outs(%[[DEST]]

// -----

func.func @affine_constant_expr_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100x1x1x1xi32>, %arg2: tensor<1x128x1x1xi32>, %dest: tensor<100x200x4x16x16x32xi32>) -> tensor<100x200x4x16x16x32xi32>
{
  %init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
  %transpose = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, 0)>,
                       affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100x1x1x1xi32>, tensor<1x128x1x1xi32>)
      outs(%init_transpose : tensor<100x200x128x256xi32>) {
    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
      %0 = arith.addi %b0, %b1 : i32
      %1 = arith.addi %0, %b2 : i32
      linalg.yield %1 : i32
  } -> tensor<100x200x128x256xi32>
  %4 = tensor.pack %transpose
    inner_dims_pos = [3, 2]
    inner_tiles = [16, 32]
    into %dest : tensor<100x200x128x256xi32> -> tensor<100x200x4x16x16x32xi32>
  return %4 : tensor<100x200x4x16x16x32xi32>
}
// CHECK-DAG:   #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
// CHECK-DAG:   #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, 0, 0, 0)>
// CHECK-DAG:   #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (0, d1, 0, 0, d5)>
// CHECK-DAG:   #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3, d4, d5)>
// CHECK-LABEL: func.func @affine_constant_expr_pack
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<100x4x200x16x16x32xi32>
// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [3, 1] inner_tiles = [16, 32]
// CHECK-SAME:      into %[[ARG0_EMPTY]]
// CHECK:         %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<1x4x1x1x32xi32>
// CHECK:         %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
// CHECK-SAME:      inner_dims_pos = [1] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG2_EMPTY]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
// CHECK-SAME:      ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
// CHECK-SAME:      outs(%[[DEST]]

// -----

func.func @transpose_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<200x4x16x100x16x32xi32>) -> tensor<200x4x16x100x16x32xi32>
{
  %init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
  %transpose = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0)>,
                       affine_map<(d0, d1, d2, d3) -> (d1)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100xi32>, tensor<128xi32>)
      outs(%init_transpose : tensor<100x200x128x256xi32>) {
    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
      %0 = arith.addi %b0, %b1 : i32
      %1 = arith.addi %0, %b2 : i32
      linalg.yield %1 : i32
  } -> tensor<100x200x128x256xi32>
  %4 = tensor.pack %transpose
    outer_dims_perm = [1, 2, 3, 0]
    inner_dims_pos = [3, 2]
    inner_tiles = [16, 32]
    into %dest : tensor<100x200x128x256xi32> -> tensor<200x4x16x100x16x32xi32>
  return %4 : tensor<200x4x16x100x16x32xi32>
}

// CHECK-DAG:   #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
// CHECK-DAG:   #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
// CHECK-DAG:   #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d5)>
// CHECK-LABEL: func.func @transpose_pack_with_outer_dims
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<200x4x16x100x16x32xi32>
// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [3, 1] inner_tiles = [16, 32]
// CHECK-SAME:      into %[[ARG0_EMPTY]]
// CHECK:         %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
// CHECK:         %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG2_EMPTY]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP]]]
// CHECK-SAME:      ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
// CHECK-SAME:      outs(%[[DEST]]

// -----

#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: tensor<128x256xi32>) -> tensor<16x4x32x16xi32>{
  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<128x256xi32>)
      outs(%init : tensor<128x256xi32>) {
    ^bb0(%arg3: i32, %arg4: i32):
      %4 = arith.addi %arg3, %arg4 : i32
      linalg.yield %4 : i32
  } -> tensor<128x256xi32>
  %empty = tensor.empty() : tensor<16x4x32x16xi32>
  %pack = tensor.pack %elem
    outer_dims_perm = [1, 0]
    inner_dims_pos = [0, 1]
    inner_tiles = [32, 16]
    into %empty : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
  return %pack : tensor<16x4x32x16xi32>
}

// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @elem_pack_transpose_outer_dims
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
// CHECK:         %[[PACKED_ARG1:.+]] = tensor.pack %[[ARG1]]
// CHECK-SAME:      outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:      into %[[ARG1_EMPTY]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:      into %[[ARG0_EMPTY]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME:      ins(%[[PACKED_ARG0]]
// CHECK-SAME:      outs(%[[PACKED_ARG1]]

// -----

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
  %0 = tensor.empty() : tensor<12x56x56x64xf32>
  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
  %2 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%1 : tensor<12x56x56x64xf32>) {
    ^bb0(%out: f32):
      %3 = arith.addf %out, %out : f32
      linalg.yield %3 : f32
  } -> tensor<12x56x56x64xf32>
  return %2 : tensor<12x56x56x64xf32>
}

// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func.func @unpack_on_output
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_EMPTY_UNPACK:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG0_EMPTY_UNPACK]]
// CHECK:         %[[ARG0_EMPTY_PACK:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG0_EMPTY_PACK]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP]]]
// CHECK-SAME:      outs(%[[PACKED_ARG0]]
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[UNPACKED_ARG0]]

// -----

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf32>) -> tensor<12x56x56x64xf32> {
  %0 = tensor.empty() : tensor<12x56x56x64xf32>
  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
  %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      %3 = arith.addf %in, %out : f32
      linalg.yield %3 : f32
  } -> tensor<12x56x56x64xf32>
  return %2 : tensor<12x56x56x64xf32>
}

// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func.func @unpack_on_input
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
// CHECK:         %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK:         %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG1_PACK_EMPTY]]
// CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK:         %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME:      ins(%[[ARG0_PACK]]
// CHECK-SAME:      outs(%[[ARG1_PACK]]
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG1]]

// -----

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
  %0 = tensor.empty() : tensor<12x56x56x64xf32>
  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
  %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf16>) {
    ^bb0(%in: f32, %out: f16):
      %3 = arith.truncf %in : f32 to f16
      linalg.yield %3 : f16
  } -> tensor<12x56x56x64xf16>
  return %2 : tensor<12x56x56x64xf16>
}

// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func.func @unpack_element_type_change
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
// CHECK:         %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
// CHECK:         %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG1_PACK_EMPTY]]
// CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK:         %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME:      ins(%[[ARG0_PACK]]
// CHECK-SAME:      outs(%[[ARG1_PACK]]
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG1]]

// -----

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
  %init = tensor.empty() : tensor<12x56x56x64xf32>
  %0 = tensor.empty() : tensor<12x56x56x64xf32>
  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
  %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      %3 = arith.addf %in, %in : f32
      linalg.yield %3 : f32
  } -> tensor<12x56x56x64xf32>
  return %2 : tensor<12x56x56x64xf32>
}

// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func.func @forward_tensor_empty
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
// CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
// CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME:      ins(%[[PACKED_ARG0]]
// CHECK-SAME:      outs(%[[DEST]]
// CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[FINAL_RES]]

// -----

func.func @pad_valid_unpack_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58x58x64xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<1x56x56x64xf32>
  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
  %padded = tensor.pad %1 low[0, 1, 1, 0] high[0, 1, 1, 0] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
  } : tensor<1x56x56x64xf32> to tensor<1x58x58x64xf32>
  return %padded : tensor<1x58x58x64xf32>
}

// CHECK-LABEL: func.func @pad_valid_unpack_propagation(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>

// -----

func.func @pad_valid_unpack_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<2x58x58x64xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<1x56x56x64xf32>
  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
  %padded = tensor.pad %1 low[1, 1, 1, 0] high[0, 1, 1, 0] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
  } : tensor<1x56x56x64xf32> to tensor<2x58x58x64xf32>
  return %padded : tensor<2x58x58x64xf32>
}

// CHECK-LABEL: func.func @pad_valid_unpack_propagation(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[1, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<2x58x58x64xf32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[EMPTY]] : tensor<2x2x58x58x32xf32> -> tensor<2x58x58x64xf32>

// -----

func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58x58x66xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<1x56x56x64xf32>
  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
  %padded = tensor.pad %1 low[0, 1, 1, 1] high[0, 1, 1, 1] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
  } : tensor<1x56x56x64xf32> to tensor<1x58x58x66xf32>
  return %padded : tensor<1x58x58x66xf32>
}

// CHECK-LABEL: func.func @pad_along_unpacked_dim(
// CHECK:         %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
// CHECK:         %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]

// -----

func.func @pad_valid_pack_propagation(%arg0: tensor<1x64x56x56xf32>) -> tensor<1x2x58x58x32xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
  } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32>
  %0 = tensor.empty() : tensor<1x2x58x58x32xf32>
  %1 = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>
  return %1 : tensor<1x2x58x58x32xf32>
}

// CHECK-LABEL: func.func @pad_valid_pack_propagation(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<1x64x56x56xf32>)
// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x56x56x32xf32>
// CHECK:         %[[PACKED:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [32]
// CHECK-SAME:      into %[[EMPTY]] : tensor<1x64x56x56xf32> -> tensor<1x2x56x56x32xf32>
// CHECK:         %[[PADDED:.+]] = tensor.pad %[[PACKED]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
// CHECK:         return %[[PADDED]]

// -----

func.func @pad_valid_outer_dims_pack_propagation(%arg0: tensor<1x64x56x56xf32>) -> tensor<1x58x58x2x32xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
  } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32>
  %0 = tensor.empty() : tensor<1x58x58x2x32xf32>
  %1 = tensor.pack %padded outer_dims_perm = [0, 3, 2, 1] inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x58x58x2x32xf32>
  return %1 : tensor<1x58x58x2x32xf32>
}

// CHECK-LABEL: func.func @pad_valid_outer_dims_pack_propagation(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<1x64x56x56xf32>)
// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x2x32xf32>
// CHECK:         %[[PACKED:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 2, 1] inner_dims_pos = [1] inner_tiles = [32]
// CHECK-SAME:      into %[[EMPTY]] : tensor<1x64x56x56xf32> -> tensor<1x56x56x2x32xf32>
// CHECK:         %[[PADDED:.+]] = tensor.pad %[[PACKED]] low[0, 1, 1, 0, 0] high[0, 1, 1, 0, 0]
// CHECK:         return %[[PADDED]]

// -----

func.func @pad_along_packed_dim(%arg0: tensor<1x60x56x56xf32>) -> tensor<1x2x58x58x32xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %padded = tensor.pad %arg0 low[0, 2, 1, 1] high[0, 2, 1, 1] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
  } : tensor<1x60x56x56xf32> to tensor<1x64x58x58xf32>
  %0 = tensor.empty() : tensor<1x2x58x58x32xf32>
  %1 = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>
  return %1 : tensor<1x2x58x58x32xf32>
}

// CHECK-LABEL: func.func @pad_along_packed_dim(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<1x60x56x56xf32>)
// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 2, 1, 1] high[0, 2, 1, 1]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x58x58x32xf32>
// CHECK:         tensor.pack %[[PADDED]] inner_dims_pos = [1] inner_tiles = [32]
// CHECK-SAME:      into %[[EMPTY]] : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>

// -----

func.func @multi_use_pad_pack_propagation(%arg0: tensor<1x64x56x56xf32>) -> (tensor<1x64x58x58xf32>, tensor<1x2x58x58x32xf32>) {
  %cst = arith.constant 0.000000e+00 : f32
  %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
  } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32>
  %0 = tensor.empty() : tensor<1x2x58x58x32xf32>
  %1 = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>
  return %padded, %1 : tensor<1x64x58x58xf32>, tensor<1x2x58x58x32xf32>
}

// CHECK-LABEL: func.func @multi_use_pad_pack_propagation(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<1x64x56x56xf32>)
// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x56x56x32xf32>
// CHECK:         %[[PACKED:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [32]
// CHECK-SAME:      into %[[EMPTY]] : tensor<1x64x56x56xf32> -> tensor<1x2x56x56x32xf32>
// CHECK:         %[[PADDED:.+]] = tensor.pad %[[PACKED]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
// CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[PADDED]] inner_dims_pos = [1] inner_tiles = [32]
// CHECK:         return %[[UNPACKED]], %[[PADDED]]

// -----

#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{
  %init = tensor.empty() : tensor<128x256xi32>
  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<128x256xi32>)
      outs(%init : tensor<128x256xi32>) {
    ^bb0(%arg3: i32, %arg4: i32):
      %4 = arith.addi %arg3, %arg3 : i32
      linalg.yield %4 : i32
  } -> tensor<128x256xi32>
  %dest = bufferization.alloc_tensor() : tensor<4x16x16x32xi32>
  %pack = tensor.pack %elem
    inner_dims_pos = [1, 0]
    inner_tiles = [16, 32]
    into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32>
  return %pack : tensor<4x16x16x32xi32>
}

// CHECK-LABEL: func.func @would_break_dominance(
// CHECK-SAME:    %[[ARG0:.+]]: tensor<128x256xi32>)
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<128x256xi32>
// CHECK-NEXT:    %[[GEN:.+]] = linalg.generic
// CHECK-SAME:      ins(%[[ARG0]]
// CHECK-SAME:      outs(%[[EMPTY]]
// CHECK:         %[[ALLOC:.+]] = bufferization.alloc_tensor() : tensor<4x16x16x32xi32>
// CHECK-NEXT:    %{{.+}} = tensor.pack %[[GEN]]
CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32] 740// CHECK-SAME: into %[[ALLOC]] 741 742// ----- 743 744#map0 = affine_map<(d0, d1, d2, d3) -> ()> 745#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> 746 747func.func @scalar_tensor(%arg0 : tensor<f32>) -> tensor<1x32x7x7x32xf32> { 748 %empty_gen = tensor.empty() : tensor<1x7x7x1024xf32> 749 %gen = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<f32>) outs(%empty_gen : tensor<1x7x7x1024xf32>) { 750 ^bb0(%in: f32, %out: f32): 751 linalg.yield %in : f32 752 } -> tensor<1x7x7x1024xf32> 753 %empty_pack = tensor.empty() : tensor<1x32x7x7x32xf32> 754 %pack = tensor.pack %gen outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %empty_pack : tensor<1x7x7x1024xf32> -> tensor<1x32x7x7x32xf32> 755 return %pack : tensor<1x32x7x7x32xf32> 756} 757 758// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()> 759// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> 760// CHECK-LABEL: func.func @scalar_tensor 761// CHECK-SAME: %[[ARG0:.+]]: tensor<f32>) 762// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x32x7x7x32xf32> 763// CHECK: linalg.generic 764// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]]] 765// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"] 766// CHECK-SAME: ins(%[[ARG0]] 767// CHECK-SAME: outs(%[[EMPTY]] 768 769// ----- 770 771#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> 772func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x56x56x64xf32> { 773 %init = tensor.empty() : tensor<12x56x56x64xf32> 774 %0 = tensor.empty() : tensor<12x56x56x64xf32> 775 %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] into %0 : tensor<12x64x56x56xf32> -> tensor<12x56x56x64xf32> 776 %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types 
= ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf32>) { 777 ^bb0(%in: f32, %out: f32): 778 %3 = arith.addf %in, %in : f32 779 linalg.yield %3 : f32 780 } -> tensor<12x56x56x64xf32> 781 return %2 : tensor<12x56x56x64xf32> 782} 783 784// CHECK-LABEL: func.func @unpack_empty_inner_dims 785// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack 786// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 787// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]] 788// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 789// CHECK: %[[RES:.+]] = linalg.generic 790// CHECK-SAME: ins(%[[PACKED_ARG0]] 791// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]] 792// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 793 794// ----- 795 796#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> 797#map1 = affine_map<(d0, d1, d2) -> (d0, d1)> 798func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>, 799 %arg1: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{ 800 %elem = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "reduction"]} 801 ins(%arg0 : tensor<128x256x32xi32>) 802 outs(%arg1 : tensor<128x256xi32>) { 803 ^bb0(%arg3: i32, %arg4: i32): 804 %4 = arith.addi %arg3, %arg4 : i32 805 linalg.yield %4 : i32 806 } -> tensor<128x256xi32> 807 %dest = tensor.empty() : tensor<4x16x16x32xi32> 808 %pack = tensor.pack %elem 809 inner_dims_pos = [1, 0] 810 inner_tiles = [16, 32] 811 into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32> 812 return %pack : tensor<4x16x16x32xi32> 813} 814// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> 815// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)> 816// CHECK-LABEL: func.func @reduction_pack_transpose_inner_dims 817// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] 818// CHECK-SAME: 
%[[ARG1:[a-zA-Z0-9]+]] 819// CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32> 820// CHECK: %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]] 821// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32] 822// CHECK-SAME: into %[[ARG1_EMPTY]] 823// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x32x16x32xi32> 824// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]] 825// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32] 826// CHECK-SAME: into %[[ARG0_EMPTY]] 827// CHECK: %[[RED:.+]] = linalg.generic 828// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] 829// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel"] 830// CHECK-SAME: ins(%[[PACK_ARG0]] 831// CHECK-SAME: outs(%[[PACK_ARG1]] 832// CHECK: return %[[RED]] : tensor<4x16x16x32xi32> 833 834// ----- 835 836func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, 837 %arg2: tensor<128xi32>, %init_reduction: tensor<100x128x256xi32>) -> tensor<4x16x100x16x32xi32> 838{ 839 %reduction = linalg.generic { 840 indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, 841 affine_map<(d0, d1, d2, d3) -> (d0)>, 842 affine_map<(d0, d1, d2, d3) -> (d1)>, 843 affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>], 844 iterator_types = ["parallel", "parallel", "reduction", "parallel"]} 845 ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100xi32>, tensor<128xi32>) 846 outs(%init_reduction : tensor<100x128x256xi32>) { 847 ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32): 848 %0 = arith.addi %b0, %b1 : i32 849 %1 = arith.addi %0, %b2 : i32 850 %2 = arith.addi %1, %b3 : i32 851 linalg.yield %2 : i32 852 } -> tensor<100x128x256xi32> 853 %init_pack = tensor.empty() : tensor<4x16x100x16x32xi32> 854 %4 = tensor.pack %reduction 855 outer_dims_perm = [1, 2, 0] 856 inner_dims_pos = [2, 1] 857 inner_tiles = [16, 32] 858 into %init_pack : tensor<100x128x256xi32> -> tensor<4x16x100x16x32xi32> 859 return %4 : 
tensor<4x16x100x16x32xi32> 860} 861 862// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> 863// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)> 864// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d5)> 865// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d5)> 866// CHECK-LABEL: func.func @reduction_pack_with_outer_dims 867// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] 868// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] 869// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] 870// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]] 871// CHECK: %[[ARG3_EMPTY:.+]] = tensor.empty() : tensor<4x16x100x16x32xi32> 872// CHECK: %[[PACKED_ARG3:.+]] = tensor.pack %[[ARG3]] 873// CHECK-SAME: outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 32] 874// CHECK-SAME: into %[[ARG3_EMPTY]] 875// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x200x100x16x32xi32> 876// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]] 877// CHECK-SAME: outer_dims_perm = [1, 3, 2, 0] inner_dims_pos = [3, 1] inner_tiles = [16, 32] 878// CHECK-SAME: into %[[ARG0_EMPTY]] 879// CHECK: %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32> 880// CHECK: %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]] 881// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [32] 882// CHECK-SAME: into %[[ARG2_EMPTY]] 883// CHECK: %[[RES:.+]] = linalg.generic 884// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]]] 885// CHECK-SAME: ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]] 886// CHECK-SAME: outs(%[[PACKED_ARG3]] 887 888// ----- 889 890#map0 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)> 891#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> 892#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d3)> 893func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>, 894 %filter: tensor<2x2xi32>) -> tensor<16x540x960xi32>{ 895 %init = 
tensor.empty() : tensor<16x540x960xi32> 896 %empty = tensor.empty() : tensor<1x16x1080x1920xi32> 897 %unpack = tensor.unpack %arg0 898 inner_dims_pos = [1] 899 inner_tiles = [16] 900 into %empty : tensor<1x1x1080x1920x16xi32> -> tensor<1x16x1080x1920xi32> 901 %pool = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} 902 ins(%unpack, %filter : tensor<1x16x1080x1920xi32>, tensor<2x2xi32>) 903 outs(%init : tensor<16x540x960xi32>) { 904 ^bb0(%in: i32, %in_1: i32, %out: i32): 905 %max = arith.maxui %in, %in_1 : i32 906 linalg.yield %max : i32 907 } -> tensor<16x540x960xi32> 908 return %pool : tensor<16x540x960xi32> 909} 910// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5, d6)> 911// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5)> 912// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d3, d6)> 913// CHECK-LABEL: func.func @unpack_different_destination_shape 914// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] 915// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] 916// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32> 917// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32> 918// CHECK: %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32> 919// CHECK: %[[PACK_ARG0:.+]] = tensor.pack 920// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16] 921// CHECK-SAME: into %[[PACK_EMPTY]] 922// CHECK: %[[POOL:.+]] = linalg.generic 923// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] 924// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"] 925// CHECK-SAME: ins(%[[PACK_ARG0]], %[[ARG1]] 926// CHECK-SAME: outs(%[[INIT]] 927// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[POOL]] 928// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16] 929// CHECK-SAME: into %[[FINAL_RES]] 
// CHECK: return %[[UNPACK]] : tensor<16x540x960xi32>

// -----

// Bubbling tensor.pack up through tensor.collapse_shape: the pack is
// re-targeted at the pre-collapse shape and the collapse is re-applied on the
// packed layout.
func.func @bubble_up_pack_through_collapse(%1: tensor<?x16x4xf32>, %dim : index) -> tensor<?x4x8x1xf32> {
  %collapsed = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<?x16x4xf32> into tensor<?x4xf32>
  %2 = tensor.empty(%dim) : tensor<?x4x8x1xf32>
  %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
  func.return %pack : tensor<?x4x8x1xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_through_collapse
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x2x4x8x1xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
// CHECK: return %[[COLLAPSED]] : tensor<?x4x8x1xf32>

// -----

// Same as above without an explicit outer_dims_perm on the pack.
func.func @bubble_up_pack_through_collapse_empty_outer_dims_perm(%1: tensor<?x16x4xf32>, %dim : index) -> tensor<?x4x8x1xf32> {
  %collapsed = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<?x16x4xf32> into tensor<?x4xf32>
  %2 = tensor.empty(%dim) : tensor<?x4x8x1xf32>
  %pack = tensor.pack %collapsed inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
  func.return %pack : tensor<?x4x8x1xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_through_collapse_empty_outer_dims_perm
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x2x4x8x1xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
// CHECK: return %[[COLLAPSED]] : tensor<?x4x8x1xf32>

// -----

// The pack's outer_dims_perm is rewritten in terms of the pre-collapse dims.
func.func @bubble_up_permuted_pack_through_collapse(%1: tensor<4x192x16x256xf32>) -> tensor<4x32x3072x8x1xf32> {
  %collapsed = tensor.collapse_shape %1 [[0], [1, 2], [3]] : tensor<4x192x16x256xf32> into tensor<4x3072x256xf32>
  %2 = tensor.empty() : tensor<4x32x3072x8x1xf32>
  %pack = tensor.pack %collapsed outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 1] into %2 : tensor<4x3072x256xf32> -> tensor<4x32x3072x8x1xf32>
  func.return %pack : tensor<4x32x3072x8x1xf32>
}
// CHECK-LABEL: func.func @bubble_up_permuted_pack_through_collapse
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x32x192x16x8x1xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<4x192x16x256xf32> -> tensor<4x32x192x16x8x1xf32>
// FIX(review): this line previously matched the literal input SSA name
// `%pack`, which is not stable in the pass output; use the captured variable.
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x192x16x8x1xf32> into tensor<4x32x3072x8x1xf32>
// CHECK: return %[[COLLAPSED]] : tensor<4x32x3072x8x1xf32>

// -----

// Collapse groups that only fold away unit dims do not block bubbling.
func.func @bubble_up_pack_through_unit_collapse(%1: tensor<1x64x1x4xf32>) -> tensor<8x4x8x1xf32> {
  %collapsed = tensor.collapse_shape %1 [[0, 1, 2], [3]] : tensor<1x64x1x4xf32> into tensor<64x4xf32>
  %2 = tensor.empty() : tensor<8x4x8x1xf32>
  %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<64x4xf32> -> tensor<8x4x8x1xf32>
  func.return %pack : tensor<8x4x8x1xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_through_unit_collapse
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x8x1x4x8x1xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 3] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<1x64x1x4xf32> -> tensor<1x8x1x4x8x1xf32>
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1, 2], [3], [4], [5]] : tensor<1x8x1x4x8x1xf32> into tensor<8x4x8x1xf32>
// CHECK: return %[[COLLAPSED]] : tensor<8x4x8x1xf32>

// -----

// Only the non-collapsed dim is tiled; the collapsed group stays untouched.
func.func @bubble_up_pack_through_collapse_on_outer_dims(%1: tensor<?x16x4xf32>, %dim : index) -> tensor<?x1x4xf32> {
  %collapsed = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<?x16x4xf32> into tensor<?x4xf32>
  %2 = tensor.empty(%dim) : tensor<?x1x4xf32>
  %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [1] inner_tiles = [4] into %2 : tensor<?x4xf32> -> tensor<?x1x4xf32>
  func.return %pack : tensor<?x1x4xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_through_collapse_on_outer_dims
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x16x1x4xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [2] inner_tiles = [4] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x16x1x4xf32>
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3]] : tensor<?x16x1x4xf32> into tensor<?x1x4xf32>
// CHECK: return %[[COLLAPSED]] : tensor<?x1x4xf32>

// -----

// Negative: the tile size (8) does not evenly map onto the trailing dim of
// the collapsed group (64x4), so the pack stays below the collapse.
func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4xf32>) -> tensor<384x32x8x8xf32> {
  %collapsed = tensor.collapse_shape %1 [[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
  %2 = tensor.empty() : tensor<384x32x8x8xf32>
  %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %2 : tensor<3072x256xf32> -> tensor<384x32x8x8xf32>
  func.return %pack : tensor<384x32x8x8xf32>
}
// CHECK-LABEL: func.func @no_bubble_up_pack_through_non_divisible_collapse
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[COLLAPSED]]
// CHECK: return %[[PACK]] : tensor<384x32x8x8xf32>

// -----

// Bubbling tensor.pack up through tensor.expand_shape: the tiling is applied
// to the corresponding source dim and the expand is re-applied afterwards.
func.func @bubble_up_pack_outer_expanded_through_expand(%arg0: tensor<32x64xf32>) -> tensor<4x2x64x4xf32> {
  %empty = tensor.empty() : tensor<4x2x64x4xf32>
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [1] inner_tiles = [4] into %empty : tensor<4x8x64xf32> -> tensor<4x2x64x4xf32>
  return %pack : tensor<4x2x64x4xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_outer_expanded_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x64x4xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<8x64x4xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3]]
// CHECK-SAME: output_shape [4, 2, 64, 4] : tensor<8x64x4xf32> into tensor<4x2x64x4xf32>
// CHECK: return %[[EXPANDED]] : tensor<4x2x64x4xf32>

// -----

// The packed dim is the innermost dim of its expanded group; bubbling applies.
func.func @bubble_up_pack_inner_expanded_through_expand(%arg0: tensor<32x64xf32>) -> tensor<32x4x4x4xf32> {
  %empty = tensor.empty() : tensor<32x4x4x4xf32>
  %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [2] inner_tiles = [4] into %empty : tensor<32x4x16xf32> -> tensor<32x4x4x4xf32>
  return %pack : tensor<32x4x4x4xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_inner_expanded_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x16x4xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [4] into %[[EMPTY]]
// CHECK-SAME: : tensor<32x64xf32> -> tensor<32x16x4xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]]
// CHECK-SAME: output_shape [32, 4, 4, 4] : tensor<32x16x4xf32> into tensor<32x4x4x4xf32>
// CHECK: return %[[EXPANDED]] : tensor<32x4x4x4xf32>

// -----

// The packed dim (0) is not part of any multi-dim expanded group.
func.func @bubble_up_pack_non_expanded_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> tensor<8x2x32x16x4xf32> {
  %empty = tensor.empty() : tensor<8x2x32x16x4xf32>
  %expanded = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [32, 2, 32, 16] : tensor<32x64x16xf32> into tensor<32x2x32x16xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [0] inner_tiles = [4] into %empty : tensor<32x2x32x16xf32> -> tensor<8x2x32x16x4xf32>
  return %pack : tensor<8x2x32x16x4xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_non_expanded_dims_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x64x16x4xf32>
// CHECK: %[[PACK:.+]] = tensor.pack
// CHECK-SAME: %[[ARG0]] inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]]
// CHECK-SAME: : tensor<32x64x16xf32> -> tensor<8x64x16x4xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3], [4]]
// CHECK-SAME: output_shape [8, 2, 32, 16, 4] : tensor<8x64x16x4xf32> into tensor<8x2x32x16x4xf32>
// CHECK: return %[[EXPANDED]] : tensor<8x2x32x16x4xf32>

// -----

// Dynamic leading dim: the new expand takes its output_shape from a fresh
// tensor.dim of the bubbled pack.
func.func @bubble_up_pack_through_expand_dynamic(%arg0: tensor<?x64xf32>) -> tensor<?x4x2x8xf32> {
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg0, %c0 : tensor<?x64xf32>
  %empty = tensor.empty(%dim) : tensor<?x4x2x8xf32>
  %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [%dim, 4, 16] : tensor<?x64xf32> into tensor<?x4x16xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [2] inner_tiles = [8] into %empty : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
  return %pack : tensor<?x4x2x8xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_through_expand_dynamic(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM_INPUT:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x64xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM_INPUT]]) : tensor<?x8x8xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [8] into %[[EMPTY]]
// CHECK-SAME: : tensor<?x64xf32> -> tensor<?x8x8xf32>
// CHECK: %[[DIM_PACK:.+]] = tensor.dim %[[PACK]], %[[C0]] : tensor<?x8x8xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]]
// CHECK-SAME: output_shape [%[[DIM_PACK]], 4, 2, 8] : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
// CHECK: return %[[EXPANDED]] : tensor<?x4x2x8xf32>

// -----

// A padding pack (padding_value) on dims not split by the expand still
// bubbles up.
func.func @bubble_up_pack_non_expanded_padding_through_expand(%arg0: tensor<32x60xf32>) -> tensor<4x2x8x4x8xf32> {
  %cst = arith.constant 3.000000e+00 : f32
  %empty = tensor.empty() : tensor<4x2x8x4x8xf32>
  // FIX(review): output_shape must agree with the static result dims
  // (tensor<4x8x60xf32>); it previously read [4, 8, 64], which the
  // expand_shape verifier rejects.
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 60] : tensor<32x60xf32> into tensor<4x8x60xf32>
  %pack = tensor.pack %expanded padding_value(%cst : f32) inner_dims_pos = [1, 2] inner_tiles = [4, 8] into %empty : tensor<4x8x60xf32> -> tensor<4x2x8x4x8xf32>
  return %pack : tensor<4x2x8x4x8xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_non_expanded_padding_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[CST:.+]] = arith.constant 3.000000e+00 : f32
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] padding_value(%[[CST]] : f32)
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %[[EMPTY]]
// CHECK-SAME: : tensor<32x60xf32> -> tensor<8x8x4x8xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]]
// CHECK-SAME: output_shape [4, 2, 8, 4, 8] : tensor<8x8x4x8xf32> into tensor<4x2x8x4x8xf32>
// CHECK: return %[[EXPANDED]] : tensor<4x2x8x4x8xf32>

// -----

// An identity outer_dims_perm is treated like no permutation, so bubbling
// still applies (contrast with the no_bubble_* permutation test below).
func.func @bubble_up_pack_outer_dims_perm_identity_through_expand(%arg0: tensor<32x64xf32>) -> tensor<4x2x32x4x2xf32> {
  %empty = tensor.empty() : tensor<4x2x32x4x2xf32>
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
  %pack = tensor.pack %expanded outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %empty : tensor<4x8x64xf32> -> tensor<4x2x32x4x2xf32>
  return %pack : tensor<4x2x32x4x2xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_outer_dims_perm_identity_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x32x4x2xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 2] into %[[EMPTY]]
// CHECK-SAME: : tensor<32x64xf32> -> tensor<8x32x4x2xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]]
// CHECK-SAME: output_shape [4, 2, 32, 4, 2] : tensor<8x32x4x2xf32> into tensor<4x2x32x4x2xf32>
// CHECK: return %[[EXPANDED]] : tensor<4x2x32x4x2xf32>

// -----

// Multiple packed dims, each in a distinct expanded group.
func.func @bubble_up_pack_multiple_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> tensor<8x2x4x8x4x8x2xf32> {
  %empty = tensor.empty() : tensor<8x2x4x8x4x8x2xf32>
  %expanded = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [32, 2, 32, 16] : tensor<32x64x16xf32> into tensor<32x2x32x16xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [0, 2, 3] inner_tiles = [4, 8, 2] into %empty : tensor<32x2x32x16xf32> -> tensor<8x2x4x8x4x8x2xf32>
  return %pack : tensor<8x2x4x8x4x8x2xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_multiple_dims_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x8x4x8x2xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME: inner_dims_pos = [0, 1, 2] inner_tiles = [4, 8, 2] into %[[EMPTY]]
// CHECK-SAME: : tensor<32x64x16xf32> -> tensor<8x8x8x4x8x2xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3], [4], [5], [6]]
// CHECK-SAME: output_shape [8, 2, 4, 8, 4, 8, 2] : tensor<8x8x8x4x8x2xf32> into tensor<8x2x4x8x4x8x2xf32>
// CHECK: return %[[EXPANDED]] : tensor<8x2x4x8x4x8x2xf32>

// -----

// The relative order of inner_dims_pos is preserved when remapped to the
// pre-expand dims ([2, 1] -> [1, 0]).
func.func @bubble_up_pack_inner_dims_reorder_through_expand(%arg0: tensor<32x64xf32>) -> tensor<4x2x4x16x4xf32> {
  %empty = tensor.empty() : tensor<4x2x4x16x4xf32>
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [2, 1] inner_tiles = [16, 4] into %empty : tensor<4x8x64xf32> -> tensor<4x2x4x16x4xf32>
  return %pack : tensor<4x2x4x16x4xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_inner_dims_reorder_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x4xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %[[EMPTY]]
// CHECK-SAME: : tensor<32x64xf32> -> tensor<8x4x16x4xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]]
// CHECK-SAME: output_shape [4, 2, 4, 16, 4] : tensor<8x4x16x4xf32> into tensor<4x2x4x16x4xf32>
// CHECK: return %[[EXPANDED]] : tensor<4x2x4x16x4xf32>

// -----

// Each packed dim belongs to a different expanded group ([0, 1] and [2, 3]),
// so the pack can still be bubbled above the expand_shape.
func.func @bubble_up_pack_multiple_different_expanded_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> tensor<4x2x2x8x16x4x4xf32> {
  %empty = tensor.empty() : tensor<4x2x2x8x16x4x4xf32>
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]] output_shape [4, 8, 2, 32, 16] : tensor<32x64x16xf32> into tensor<4x8x2x32x16xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [1, 3] inner_tiles = [4, 4] into %empty : tensor<4x8x2x32x16xf32> -> tensor<4x2x2x8x16x4x4xf32>
  return %pack : tensor<4x2x2x8x16x4x4xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_multiple_different_expanded_dims_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x16x16x4x4xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %[[EMPTY]]
// CHECK-SAME: : tensor<32x64x16xf32> -> tensor<8x16x16x4x4xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2, 3], [4], [5], [6]]
// CHECK-SAME: output_shape [4, 2, 2, 8, 16, 4, 4] : tensor<8x16x16x4x4xf32> into tensor<4x2x2x8x16x4x4xf32>
// CHECK: return %[[EXPANDED]] : tensor<4x2x2x8x16x4x4xf32>

// -----

// Negative: a non-identity outer_dims_perm on the pack blocks bubbling
// through expand_shape; the IR is left unchanged.
func.func @no_bubble_up_pack_outer_dims_permutation_through_expand(%arg0: tensor<32x64xf32>) -> tensor<32x4x2x4x2xf32> {
  %empty = tensor.empty() : tensor<32x4x2x4x2xf32>
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
  %pack = tensor.pack %expanded outer_dims_perm = [2, 0, 1] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %empty : tensor<4x8x64xf32> -> tensor<32x4x2x4x2xf32>
  return %pack : tensor<32x4x2x4x2xf32>
}
// CHECK-LABEL: func.func @no_bubble_up_pack_outer_dims_permutation_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x4x2x4x2xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]]
// CHECK-SAME: output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]]
// CHECK-SAME: outer_dims_perm = [2, 0, 1] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %[[EMPTY]]
// CHECK-SAME: : tensor<4x8x64xf32> -> tensor<32x4x2x4x2xf32>
// CHECK: return %[[PACK]] : tensor<32x4x2x4x2xf32>

// -----

// Negative: both packed dims ([0] and [1]) come from the same expanded group
// [0, 1], which cannot be mapped back onto a single source dim.
func.func @no_bubble_up_pack_multiple_same_expanded_dim_through_expand(%arg0: tensor<32x64xf32>) -> tensor<2x2x64x2x4xf32> {
  %empty = tensor.empty() : tensor<2x2x64x2x4xf32>
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %empty : tensor<4x8x64xf32> -> tensor<2x2x64x2x4xf32>
  return %pack : tensor<2x2x64x2x4xf32>
}
// CHECK-LABEL: func.func @no_bubble_up_pack_multiple_same_expanded_dim_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x2x64x2x4xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]]
// CHECK-SAME: output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %[[EMPTY]]
// CHECK-SAME: : tensor<4x8x64xf32> -> tensor<2x2x64x2x4xf32>
// CHECK: return %[[PACK]] : tensor<2x2x64x2x4xf32>

// -----

// Negative: the packed dim (0) is not the innermost dim of its expanded
// group [0, 1], so the tiling cannot be pushed onto the source dim.
func.func @no_bubble_up_pack_non_innermost_expanded_dim_through_expand(%arg0: tensor<32x64xf32>) -> tensor<2x8x64x2xf32> {
  %empty = tensor.empty() : tensor<2x8x64x2xf32>
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [0] inner_tiles = [2] into %empty : tensor<4x8x64xf32> -> tensor<2x8x64x2xf32>
  return %pack : tensor<2x8x64x2xf32>
}
// CHECK-LABEL: func.func @no_bubble_up_pack_non_innermost_expanded_dim_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x8x64x2xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]]
// CHECK-SAME: output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]]
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [2] into %[[EMPTY]]
// CHECK-SAME: : tensor<4x8x64xf32> -> tensor<2x8x64x2xf32>
// CHECK: return %[[PACK]] : tensor<2x8x64x2xf32>

// -----

// Negative: the pack pads a dim created by the expand (10 -> 2x8 tiles), so
// the padding cannot be re-associated with the pre-expand shape.
func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(%arg0: tensor<30x60xf32>) -> tensor<3x2x60x8xf32> {
  %cst = arith.constant 3.000000e+00 : f32
  %empty = tensor.empty() : tensor<3x2x60x8xf32>
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 10, 60] : tensor<30x60xf32> into tensor<3x10x60xf32>
  %pack = tensor.pack %expanded padding_value(%cst : f32) inner_dims_pos = [1] inner_tiles = [8] into %empty : tensor<3x10x60xf32> -> tensor<3x2x60x8xf32>
  return %pack : tensor<3x2x60x8xf32>
}
// CHECK-LABEL: func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[CST:.+]] = arith.constant 3.000000e+00 : f32
1264// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x2x60x8xf32> 1265// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]] 1266// CHECK-SAME: output_shape [3, 10, 60] : tensor<30x60xf32> into tensor<3x10x60xf32> 1267// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]] padding_value(%[[CST]] : f32) 1268// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [8] into %[[EMPTY]] 1269// CHECK-SAME: : tensor<3x10x60xf32> -> tensor<3x2x60x8xf32> 1270// CHECK: return %[[PACK]] : tensor<3x2x60x8xf32> 1271 1272// ----- 1273 1274func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> { 1275 %empty = tensor.empty() : tensor<8x4x16x8xf32> 1276 %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32> 1277 %pack = tensor.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32> 1278 return %pack : tensor<8x4x16x8xf32> 1279} 1280// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate( 1281// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] 1282// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32> 1283// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] 1284// CHECK-SAME: output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32> 1285// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]] 1286// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]] 1287// CHECK-SAME: : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32> 1288// CHECK: return %[[PACK]] : tensor<8x4x16x8xf32> 1289 1290// ----- 1291 1292func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> { 1293 %6 = tensor.empty(%dim) : tensor<?x256xf32> 1294 %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : 
tensor<?x32x8x8xf32> -> tensor<?x256xf32> 1295 %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [%sz0, 256, 256] : tensor<?x256xf32> into tensor<?x256x256xf32> 1296 func.return %expanded : tensor<?x256x256xf32> 1297} 1298// CHECK-LABEL: func.func @push_down_unpack_through_expand 1299// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] 1300// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] 1301// CHECK: %[[C32:.+]] = arith.constant 32 : index 1302// CHECK: %[[C0:.+]] = arith.constant 0 : index 1303// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x32x8x8xf32> 1304// CHECK: %[[SZ0:.+]] = arith.divsi %[[DIM0]], %[[C32]] : index 1305// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape [%[[SZ0]], 32, 32, 8, 8] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32> 1306// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32> 1307// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32> 1308// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32> 1309// CHECK: return %[[UNPACK]] : tensor<?x256x256xf32> 1310 1311// ----- 1312 1313func.func @push_down_unpack_through_expand_empty_outer_dims_perm(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> { 1314 %6 = tensor.empty(%dim) : tensor<?x256xf32> 1315 %unpack = tensor.unpack %5 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32> 1316 %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [%sz0, 256, 256] : tensor<?x256xf32> into tensor<?x256x256xf32> 1317 func.return %expanded : tensor<?x256x256xf32> 1318} 1319// CHECK-LABEL: func.func @push_down_unpack_through_expand_empty_outer_dims_perm 1320// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] 1321// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] 1322// CHECK: %[[C32:.+]] = arith.constant 
32 : index 1323// CHECK: %[[C0:.+]] = arith.constant 0 : index 1324// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x32x8x8xf32> 1325// CHECK: %[[SZ0:.+]] = arith.divsi %[[DIM0]], %[[C32]] : index 1326// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape [%[[SZ0]], 32, 32, 8, 8] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32> 1327// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32> 1328// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32> 1329// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32> 1330// CHECK: return %[[UNPACK]] : tensor<?x256x256xf32> 1331 1332// ----- 1333 1334func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>) -> tensor<4x12x256x256xf32> { 1335 %6 = tensor.empty() : tensor<4x3072x256xf32> 1336 %unpack = tensor.unpack %5 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 8] into %6 : tensor<4x32x384x8x8xf32> -> tensor<4x3072x256xf32> 1337 %expanded = tensor.expand_shape %unpack [[0], [1, 2], [3]] output_shape [4, 12, 256, 256] : tensor<4x3072x256xf32> into tensor<4x12x256x256xf32> 1338 func.return %expanded : tensor<4x12x256x256xf32> 1339} 1340// CHECK-LABEL: @push_down_permuted_unpack_through_expand 1341// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] 1342// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3], [4], [5]] output_shape [4, 32, 12, 32, 8, 8] : tensor<4x32x384x8x8xf32> into tensor<4x32x12x32x8x8xf32> 1343// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x12x256x256xf32> 1344// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<4x32x12x32x8x8xf32> -> tensor<4x12x256x256xf32> 1345// CHECK: return %[[UNPACK]] : tensor<4x12x256x256xf32> 1346 
// -----

// The expand_shape introduces a unit dim inside a reassociation group
// (48 -> 3x16x1); the unpack is still pushed below it.
func.func @push_down_unpack_through_unit_expand(%5: tensor<6x32x8x8xf32>) -> tensor<3x16x1x256xf32> {
  %6 = tensor.empty() : tensor<48x256xf32>
  %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<6x32x8x8xf32> -> tensor<48x256xf32>
  %expanded = tensor.expand_shape %unpack [[0, 1, 2], [3]] output_shape [3, 16, 1, 256] : tensor<48x256xf32> into tensor<3x16x1x256xf32>
  func.return %expanded : tensor<3x16x1x256xf32>
}
// CHECK-LABEL: func.func @push_down_unpack_through_unit_expand
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4], [5]] output_shape [3, 2, 1, 32, 8, 8] : tensor<6x32x8x8xf32> into tensor<3x2x1x32x8x8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x16x1x256xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 3] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<3x2x1x32x8x8xf32> -> tensor<3x16x1x256xf32>
// CHECK: return %[[UNPACK]] : tensor<3x16x1x256xf32>

// -----

// The expand_shape splits only the dynamic dim 0, which the unpack does not
// tile (inner_dims_pos = [1]); the unpack is pushed down with the tiled dim
// remapped to position 2 of the expanded operand.
func.func @push_down_unpack_through_expand_on_outer_dims(%5: tensor<?x32x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
  %6 = tensor.empty(%dim) : tensor<?x256xf32>
  %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [1] inner_tiles = [8] into %6 : tensor<?x32x8xf32> -> tensor<?x256xf32>
  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [%sz0, 256, 256] : tensor<?x256xf32> into tensor<?x256x256xf32>
  func.return %expanded : tensor<?x256x256xf32>
}
// CHECK-LABEL: func.func @push_down_unpack_through_expand_on_outer_dims
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK: %[[C256:.+]] = arith.constant 256 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x32x8xf32>
// CHECK: %[[SZ0:.+]] = arith.divsi %[[DIM0]], %[[C256]] : index
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]] output_shape [%[[SZ0]], 256, 32, 8] : tensor<?x32x8xf32> into tensor<?x256x32x8xf32>
// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x256x32x8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [2] inner_tiles = [8] into %[[EMPTY]] : tensor<?x256x32x8xf32> -> tensor<?x256x256xf32>
// CHECK: return %[[UNPACK]] : tensor<?x256x256xf32>

// -----

// Negative test: the expand splits the unpacked dim 3072 into 256x12, and the
// innermost expanded size (12) is not divisible by the inner tile (8), so the
// unpack stays above the expand_shape.
func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x8xf32>) -> tensor<256x12x256xf32> {
  %6 = tensor.empty() : tensor<3072x256xf32>
  %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<384x32x8x8xf32> -> tensor<3072x256xf32>
  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
  func.return %expanded : tensor<256x12x256xf32>
}
// CHECK-LABEL: func.func @no_push_down_unpack_through_non_divisible_expand
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
// CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32>