// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-expansion -split-input-file | FileCheck %s

#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
#map2 = affine_map<(d0, d1, d2) -> ()>
func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
                                              %arg1 : tensor<?x?x?xf32>,
                                              %arg2 : f32) ->
                                              tensor<?x?x?xf32>
{
  %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] :
    tensor<?x?x4x?xf32> into tensor<?x?x?xf32>
  %1 = linalg.generic {
    indexing_maps = [#map0, #map1, #map2, #map1],
    iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%0, %arg1, %arg2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32)
    outs(%arg1 : tensor<?x?x?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %s: f32):
    %1 = arith.mulf %arg3, %arg4 : f32
    %2 = arith.addf %1, %arg5 : f32
    linalg.yield %2 : f32
  } -> tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}

// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d0, d1)>
// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3) -> ()>
// CHECK: func @generic_op_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32
// CHECK: %[[C4:.+]] = arith.constant 4 : index
// CHECK: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?x?xf32>
// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM_1]], %[[C4]] : index
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM]], %[[DIM_0]], %[[VAL_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?x?xf32>
// CHECK: %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_4]], %[[C4]] : index
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_2]], %[[DIM_3]], %[[VAL_1]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[ARG0]], %[[T1]], %[[ARG2]] : tensor<?x?x4x?xf32>, tensor<?x?x?x4xf32>, f32)
// CHECK-SAME: outs(%[[T2]] : tensor<?x?x?x4xf32>)
// CHECK: %[[T4:.+]] = tensor.collapse_shape %[[T3]]
// CHECK-SAME: [0], [1], [2, 3]
// CHECK-SAME: tensor<?x?x?x4xf32> into tensor<?x?x?xf32>
// CHECK: return %[[T4]]

// -----

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> ()>
func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
                                              %arg1 : tensor<?x?xf32>,
                                              %arg2 : f32,
                                              %sz0: index,
                                              %sz1: index) ->
                                              tensor<?x4x?x5xf32>
{
  %0 = linalg.generic {
    indexing_maps = [#map0, #map0, #map1, #map0],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg0, %arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>, f32)
    outs(%arg0 : tensor<?x?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %s: f32):
    %1 = arith.mulf %arg3, %arg4 : f32
    %2 = arith.addf %1, %arg5 : f32
    linalg.yield %2 : f32
  } -> tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, 4, %sz1, 5] :
    tensor<?x?xf32> into tensor<?x4x?x5xf32>
  return %1 : tensor<?x4x?x5xf32>
}

// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> ()>

// CHECK: func @generic_op_reshape_consumer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
// CHECK: %[[C20:.+]] = arith.constant 20 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C20]] : index
// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], 4, %[[VAL_0]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], 4, %[[VAL_1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], 4, %[[VAL_2]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]], %[[T1]], %[[ARG2]] : tensor<?x4x?x5xf32>, tensor<?x4x?x5xf32>, f32)
// CHECK-SAME: outs(%[[T2]] : tensor<?x4x?x5xf32>)
// CHECK: return %[[T3]] : tensor<?x4x?x5xf32>


// -----

func.func @reshape_as_consumer_permutation
  (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index, %sz1: index, %sz2: index)
    -> tensor<?x2x?x3x4x?xf32> {
  %c = linalg.generic {
         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
                          affine_map<(d0, d1, d2) -> (d1, d2)>,
                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
         iterator_types = ["parallel", "parallel", "parallel"]}
         ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
        outs(%a : tensor<?x?x?xf32>) {
       ^bb0(%arg0 : f32, %arg1: f32, %s: f32):
         %1 = arith.addf %arg0, %arg1 : f32
         linalg.yield %1 : f32
       } -> tensor<?x?x?xf32>
  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] output_shape [%sz0, 2, %sz1, 3, 4, %sz2] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
  return %d : tensor<?x2x?x3x4x?xf32>
}
// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
// CHECK-DAG: #[[MAP10:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
// CHECK: func @reshape_as_consumer_permutation
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index
// CHECK: %[[C12:.+]] = arith.constant 12 : index
// CHECK: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C12]] : index
// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C2]] : index
// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[VAL_0]], %[[VAL_1]], 2, %[[DIM_1]]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_2]], %[[C12]] : index
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[VAL_2]], %[[DIM_3]]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
// CHECK: %[[DIM_6:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
// CHECK: %[[DIM_7:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_5]], %[[C2]] : index
// CHECK: %[[VAL_4:.+]] = arith.divsi %[[DIM_7]], %[[C12]] : index
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[VAL_3]], 2, %[[DIM_6]], 3, 4, %[[VAL_4]]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<3x4x?x?x2x?xf32>, tensor<3x4x?x?xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<?x2x?x3x4x?xf32>)
// CHECK: return %[[T3]] : tensor<?x2x?x3x4x?xf32>

// -----

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d2)>

func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
                                              -> tensor<8x33x4xf32> {
  %cst = arith.constant dense<2.000000e+00> : tensor<264x4xf32>
  %0 = tensor.empty() : tensor<264x4xf32>
  %1 = linalg.generic {
    indexing_maps = [#map0, #map0, #map0],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg0, %cst : tensor<264x4xf32>, tensor<264x4xf32>)
    outs(%0 : tensor<264x4xf32>) {
  ^bb0(%arg1: f32, %arg2: f32, %s: f32):
    %2 = arith.mulf %arg1, %arg2 : f32
    linalg.yield %2 : f32
  } -> tensor<264x4xf32>
  %2 = tensor.expand_shape %1 [[0, 1], [2]] output_shape [8, 33, 4] :
    tensor<264x4xf32> into tensor<8x33x4xf32>
  return %2 : tensor<8x33x4xf32>
}

// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @generic_op_reshape_consumer_static
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32>
// CHECK-DAG: %[[CST:.+]] = arith.constant
// CHECK-SAME: : tensor<8x33x4xf32>
// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
// CHECK: %[[T2:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]], %[[CST]] :
// CHECK-SAME: outs(%[[T1]] : tensor<8x33x4xf32>)
// CHECK: return %[[T2]] : tensor<8x33x4xf32>

// -----

#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
                                                    %arg1 : tensor<?x?x?xi32>) ->
                                                    tensor<?x?x?xi32>
{
  %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] :
    tensor<?x?x4x?xi32> into tensor<?x?x?xi32>
  %1 = linalg.generic {
    indexing_maps = [#map0, #map1, #map1],
    iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%0, %arg1 : tensor<?x?x?xi32>, tensor<?x?x?xi32>)
    outs(%0 : tensor<?x?x?xi32>) {
  ^bb0(%arg3: i32, %arg4: i32, %s: i32):
    %idx0 = linalg.index 0 : index
    %idx1 = linalg.index 1 : index
    %idx2 = linalg.index 2 : index
    %1 = arith.muli %arg3, %arg4 : i32
    %2 = arith.index_cast %idx0 : index to i32
    %3 = arith.addi %1, %2 : i32
    %4 = arith.index_cast %idx1 : index to i32
    %5 = arith.addi %3, %4 : i32
    %6 = arith.index_cast %idx2 : index to i32
    %7 = arith.addi %5, %6 : i32
    linalg.yield %7 : i32
  } -> tensor<?x?x?xi32>
  return %1 : tensor<?x?x?xi32>
}

// Only check the body in the indexed version of the test.
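// Note: fusing the producer reshape splits the collapsed ?x4 dimension into two
// loops, so the old `linalg.index 0` is rebuilt by linearizing the two new
// indices as idx1 + idx0 * 4 (the affine.apply on #[[MAP]] below), while the
// remaining indices shift to `linalg.index 2` and `linalg.index 3`.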
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
// CHECK: func @indexed_consumer_reshape_producer_fusion
// CHECK: linalg.generic
// CHECK: ^{{.*}}(
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: i32, %[[ARG4:[a-zA-Z0-9_]+]]: i32,
// CHECK-SAME: %[[ARG8:[a-zA-Z0-9_]+]]: i32)
// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
// CHECK-DAG: %[[T3:.+]] = affine.apply #[[MAP]](%[[IDX1]], %[[IDX0]])
// CHECK: %[[T4:.+]] = arith.muli %[[ARG3]], %[[ARG4]]
// CHECK: %[[T5:.+]] = arith.index_cast %[[T3]]
// CHECK: %[[T6:.+]] = arith.addi %[[T4]], %[[T5]]
// CHECK: %[[T7:.+]] = arith.index_cast %[[IDX2]]
// CHECK: %[[T8:.+]] = arith.addi %[[T6]], %[[T7]]
// CHECK: %[[T9:.+]] = arith.index_cast %[[IDX3]]
// CHECK: %[[T10:.+]] = arith.addi %[[T8]], %[[T9]]
// CHECK: linalg.yield %[[T10]]

// -----

#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
                                                    %arg1 : tensor<?x?xi32>,
                                                    %sz0: index, %sz1: index) ->
                                                    tensor<?x?x4x5xi32>
{
  %0 = linalg.generic {
    indexing_maps = [#map0, #map0, #map0],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
    outs(%arg0 : tensor<?x?xi32>) {
  ^bb0(%arg3: i32, %arg4: i32, %s: i32):
    %idx0 = linalg.index 0 : index
    %idx1 = linalg.index 1 : index
    %1 = arith.muli %arg3, %arg4 : i32
    %2 = arith.index_cast %idx0 : index to i32
    %3 = arith.addi %1, %2 : i32
    %4 = arith.index_cast %idx1 : index to i32
    %5 = arith.addi %3, %4 : i32
    linalg.yield %5 : i32
  } -> tensor<?x?xi32>
  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, %sz1, 4, 5] :
    tensor<?x?xi32> into tensor<?x?x4x5xi32>
  return %1 : tensor<?x?x4x5xi32>
}

// Only check the body in the indexed version of the test.
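// Note: here the consumer expand_shape splits the second loop into three
// (? x 4 x 5), so the old `linalg.index 1` is re-linearized from the new
// indices in two steps: t1 = idx2 + idx1 * 4, then idx3 + t1 * 5, matching
// the two affine.apply ops on #[[MAP1]] and #[[MAP2]] below.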
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 5)>
// CHECK: func @indexed_producer_reshape_consumer_fusion
// CHECK: linalg.generic
// CHECK: ^{{.*}}(
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: i32, %[[ARG4:[a-zA-Z0-9_]+]]: i32,
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: i32)
// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
// CHECK: %[[T1:.+]] = affine.apply #[[MAP1]](%[[IDX2]], %[[IDX1]])
// CHECK: %[[T2:.+]] = affine.apply #[[MAP2]](%[[IDX3]], %[[T1]])
// CHECK: %[[T4:.+]] = arith.muli %[[ARG3]], %[[ARG4]]
// CHECK: %[[T5:.+]] = arith.index_cast %[[IDX0]]
// CHECK: %[[T6:.+]] = arith.addi %[[T4]], %[[T5]]
// CHECK: %[[T7:.+]] = arith.index_cast %[[T2]]
// CHECK: %[[T8:.+]] = arith.addi %[[T6]], %[[T7]]
// CHECK: linalg.yield %[[T8]]

// -----

func.func @reshape_as_consumer_permutation
  (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
    -> tensor<2x3x4x5x6x7xi32> {
  %shape = tensor.empty() : tensor<6x4x210xi32>
  %c = linalg.generic {
         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
                          affine_map<(d0, d1, d2) -> (d1, d2)>,
                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
         iterator_types = ["parallel", "parallel", "parallel"]}
         ins(%a, %b : tensor<210x6x4xi32>, tensor<210x4xi32>)
        outs(%shape : tensor<6x4x210xi32>) {
       ^bb0(%arg3 : i32, %arg4: i32, %s: i32):
         %idx0 = linalg.index 0 : index
         %idx1 = linalg.index 1 : index
         %idx2 = linalg.index 2 : index
         %1 = arith.addi %arg3, %arg4 : i32
         %2 = arith.index_cast %idx0 : index to i32
         %3 = arith.addi %1, %2 : i32
         %4 = arith.index_cast %idx1 : index to i32
         %5 = arith.addi %3, %4 : i32
         %6 = arith.index_cast %idx2 : index to i32
         %7 = arith.addi %5, %6 : i32
         linalg.yield %7 : i32
       } -> tensor<6x4x210xi32>
  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
  return %d : tensor<2x3x4x5x6x7xi32>
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 6)>
// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 7)>
// CHECK: func @reshape_as_consumer_permutation
// CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32>
// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [5, 6, 7, 2, 3, 4] : tensor<210x6x4xi32> into tensor<5x6x7x2x3x4xi32>
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [5, 6, 7, 4] : tensor<210x4xi32> into tensor<5x6x7x4xi32>
// CHECK: %[[T3:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
// CHECK: %[[T4:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
// CHECK-SAME: outs(%[[T3]] : tensor<2x3x4x5x6x7xi32>)
// CHECK: ^{{.+}}(
// CHECK-SAME: %[[ARG8:[a-zA-Z0-9_]+]]: i32, %[[ARG9:[a-zA-Z0-9_]+]]: i32,
// CHECK-SAME: %[[ARG10:[a-zA-Z0-9_]+]]: i32)
// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
// CHECK-DAG: %[[IDX4:.+]] = linalg.index 4 : index
// CHECK-DAG: %[[IDX5:.+]] = linalg.index 5 : index
// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP3]](%[[IDX1]], %[[IDX0]])
// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP4]](%[[IDX3]], %[[IDX2]])
// CHECK-DAG: %[[T7:.+]] = affine.apply #[[MAP5]](%[[IDX4]], %[[T6]])
// CHECK-DAG: %[[T8:.+]] = arith.addi %[[ARG8]], %[[ARG9]]
// CHECK: %[[T9:.+]] = arith.index_cast %[[T5]]
// CHECK: %[[T10:.+]] = arith.addi %[[T8]], %[[T9]]
// CHECK: %[[T11:.+]] = arith.index_cast %[[T7]]
// CHECK: %[[T12:.+]] = arith.addi %[[T10]], %[[T11]]
// CHECK: %[[T13:.+]] = arith.index_cast %[[IDX5]]
// CHECK: %[[T14:.+]] = arith.addi %[[T12]], %[[T13]]

// -----

func.func @reshape_as_producer_projected_permutation(
    %arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32>
{
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
    : tensor<33x8x?xi32> into tensor<264x?xi32>
  %1 = linalg.generic
    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
     iterator_types = ["parallel", "parallel", "parallel"]}
     ins(%0 : tensor<264x?xi32>)
    outs(%shape : tensor<264x?x4xi32>) {
  ^bb0(%arg1: i32, %s: i32):
    %idx0 = linalg.index 0 : index
    %idx1 = linalg.index 1 : index
    %idx2 = linalg.index 2 : index
    %2 = arith.index_cast %idx0 : index to i32
    %3 = arith.addi %arg1, %2 : i32
    %4 = arith.index_cast %idx1 : index to i32
    %5 = arith.addi %3, %4 : i32
    %6 = arith.index_cast %idx2 : index to i32
    %7 = arith.addi %5, %6 : i32
    linalg.yield %7 : i32
  } -> tensor<264x?x4xi32>
  return %1 : tensor<264x?x4xi32>
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 8)>
// CHECK: @reshape_as_producer_projected_permutation
// CHECK-SAME: %[[ARG0:.+]]: tensor<33x8x?xi32>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: ins(%[[ARG0]] : tensor<33x8x?xi32>)
// CHECK: ^{{.+}}(
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: i32,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: i32)
// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
// CHECK-DAG: %[[T0:.+]] = affine.apply #[[MAP2]](%[[IDX1]], %[[IDX0]])
// CHECK: %[[T1:.+]] = arith.index_cast %[[T0]] : index to i32
// CHECK: %[[T2:.+]] = arith.addi %[[ARG1]], %[[T1]] : i32
// CHECK: %[[T3:.+]] = arith.index_cast %[[IDX2]] : index to i32
// CHECK: %[[T4:.+]] = arith.addi %[[T2]], %[[T3]] : i32
// CHECK: %[[T5:.+]] = arith.index_cast %[[IDX3]] : index to i32
// CHECK: %[[T6:.+]] = arith.addi %[[T4]], %[[T5]] : i32
// CHECK: linalg.yield %[[T6]] : i32
// CHECK: %[[RES2:.+]] = tensor.collapse_shape %[[RES]]
// CHECK-SAME: [0, 1], [2], [3]
// CHECK-SAME: : tensor<33x8x?x4xi32> into tensor<264x?x4xi32>
// CHECK: return %[[RES2]] : tensor<264x?x4xi32>

// -----

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1, d0)>
func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
                                                        %arg1 : tensor<?x?xf32>,
                                                        %sz0: index, %sz1: index) ->
                                                        tensor<?x?x4x5xf32>
{
  %0 = linalg.generic {
    indexing_maps = [#map0, #map0, #map1],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%arg0 : tensor<?x?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %s: f32):
    %1 = arith.mulf %arg3, %arg4 : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, %sz1, 4, 5] :
    tensor<?x?xf32> into tensor<?x?x4x5xf32>
  return %1 : tensor<?x?x4x5xf32>
}

// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
// CHECK: func @generic_op_reshape_consumer_fusion_projected
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
// CHECK: %[[C20:.+]] = arith.constant 20 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C20]] : index
// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_0]], 4, 5, %[[DIM_0]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_1]], %[[C20]] : index
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_1]], 4, 5, %[[DIM_2]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP4]], #[[MAP5]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<?x4x5x?xf32>, tensor<?x4x5x?xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<?x?x4x5xf32>)
// CHECK: return %[[T3]] : tensor<?x?x4x5xf32>

// -----

func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
  %c0 = arith.constant 0 : index
  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?xf32>
  %2 = tensor.empty(%1) : tensor<?xf32>
  %3 = linalg.generic {
    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
    iterator_types = ["parallel"]}
    ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {
  ^bb0(%arg1 : f32, %arg2: f32):
    %4 = arith.addf %arg1, %arg1 : f32
    linalg.yield %4 : f32
  } -> tensor<?xf32>
  return %3 : tensor<?xf32>
}

// CHECK: func @no_fuse_dynamic_dims
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?xf32>)
// CHECK: return %[[GENERIC]]

// -----

func.func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor<?xi64>) -> tensor<2xi64> {
  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x1xi64> into tensor<2xi64>
  %1 = tensor.empty() : tensor<2xi64>
  %2 = linalg.generic
    {indexing_maps = [affine_map<(d0) -> (d0)>,
                      affine_map<(d0) -> (d0)>,
                      affine_map<(d0) -> (d0)>],
     iterator_types = ["parallel"]}
     ins(%0, %arg1 : tensor<2xi64>, tensor<?xi64>)
    outs(%1 : tensor<2xi64>) {
  ^bb0(%arg4: i64, %arg5: i64, %arg6: i64):
    %3 = arith.addi %arg4, %arg5 : i64
    linalg.yield %3 : i64
  } -> tensor<2xi64>
  return %2 : tensor<2xi64>
}

// CHECK: func @no_fuse_mismatched_dynamism
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x1xi64>
// CHECK-SAME: %[[ARG1:.+]]: tensor<?xi64>
// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : tensor<2xi64>, tensor<?xi64>)
// CHECK: return %[[GENERIC]]

// -----

func.func @reshape_as_consumer_permutation_with_multiple_results
  (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
   %sz1: index, %sz2: index, %sz3: index, %sz4: index)
    -> (tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>) {
  %c:2 = linalg.generic {
         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
                          affine_map<(d0, d1, d2) -> (d1, d2)>,
                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>,
                          affine_map<(d0, d1, d2) -> (d2, d0, d1)>],
         iterator_types = ["parallel", "parallel", "parallel"]}
         ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
        outs(%a, %a : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
       ^bb0(%arg0 : f32, %arg1: f32, %s: f32, %t : f32):
         %1 = arith.addf %arg0, %arg1 : f32
         linalg.yield %1, %1 : f32, f32
       } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
  %d = tensor.expand_shape %c#0 [[0, 1], [2], [3, 4, 5]] output_shape [%sz0, 2, %sz1, 3, 4, %sz2] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
  %e = tensor.expand_shape %c#1 [[0], [1, 2], [3, 4, 5]] output_shape [%sz3, %sz4, 2, 3, 4, %sz2] : tensor<?x?x?xf32> into tensor<?x?x2x3x4x?xf32>
  return %d, %e : tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5, d0, d1, d2, d3, d4)>
// CHECK: func @reshape_as_consumer_permutation_with_multiple_results
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index, %[[SZ3:.+]]: index, %[[SZ4:.+]]: index
// CHECK: %[[C12:.+]] = arith.constant 12 : index
// CHECK: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C12]] : index
// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C2]] : index
// CHECK: %[[RESHAPE0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[VAL_0]], %[[VAL_1]], 2, %[[DIM_1]]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_2]], %[[C12]] : index
// CHECK: %[[RESHAPE1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[VAL_2]], %[[DIM_3]]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
// CHECK: %[[DIM_6:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
// CHECK: %[[DIM_7:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_5]], %[[C2]] : index
// CHECK: %[[VAL_4:.+]] = arith.divsi %[[DIM_7]], %[[C12]] : index
// CHECK: %[[RESHAPE2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[VAL_3]], 2, %[[DIM_6]], 3, 4, %[[VAL_4]]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
// CHECK: %[[DIM_9:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
// CHECK: %[[DIM_10:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
// CHECK: %[[DIM_11:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_5:.+]] = arith.divsi %[[DIM_10]], %[[C2]] : index
// CHECK: %[[VAL_6:.+]] = arith.divsi %[[DIM_11]], %[[C12]] : index
// CHECK: %[[RESHAPE3:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2], [3, 4, 5]] output_shape [%[[DIM_9]], %[[VAL_5]], 2, 3, 4, %[[VAL_6]]] : tensor<?x?x?xf32> into tensor<?x?x2x3x4x?xf32>
// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
// CHECK-SAME: ins(%[[RESHAPE0]], %[[RESHAPE1]] :
// CHECK-SAME: outs(%[[RESHAPE2]], %[[RESHAPE3]] :
// CHECK: return %[[GENERIC]]#0, %[[GENERIC]]#1

// -----

#map0 = affine_map<(d0, d1) -> (d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @multi_result_op_expansion(%arg0: tensor<512xf32>, %arg1: tensor<512xf32>,
      %arg2: tensor<512xf32>, %arg3: tensor<200x512xf32>) -> tensor<25x8x1x512xf32> {
    %0:2 = linalg.generic {
        indexing_maps = [#map0, #map0, #map0, #map1],
        iterator_types = ["parallel", "parallel"]}
        ins(%arg0, %arg1 : tensor<512xf32>, tensor<512xf32>)
        outs(%arg2, %arg3 : tensor<512xf32>, tensor<200x512xf32>) {
    ^bb0(%arg4: f32, %arg5: f32, %arg6: f32, %arg7: f32):
      %2 = arith.addf %arg4, %arg5 : f32
      linalg.yield %2, %2 : f32, f32
    } -> (tensor<512xf32>, tensor<200x512xf32>)
    %1 = tensor.expand_shape %0#1 [[0, 1, 2], [3]] output_shape [25, 8, 1, 512] : tensor<200x512xf32> into tensor<25x8x1x512xf32>
    return %1 : tensor<25x8x1x512xf32>
  }
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func.func @multi_result_op_expansion(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<512xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<512xf32>
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: tensor<200x512xf32>
// CHECK: %[[OUTS:.+]] = tensor.expand_shape %[[ARG3]] {{\[\[}}0, 1, 2], [3]] output_shape [25, 8, 1, 512] : tensor<200x512xf32> into tensor<25x8x1x512xf32>
// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]], #[[MAP1]]]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK-SAME: outs(%[[ARG2]], %[[OUTS]] :
// CHECK: return %[[GENERIC]]#1

// -----

#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @generic_op_reshape_consumer_fusion_reduction(%arg0 : tensor<?x?xf32>,
                                                        %arg1 : tensor<?x?xf32>,
                                                        %arg2 : tensor<?x?xf32>,
                                                        %sz0: index,
                                                        %sz1: index) ->
                                                        tensor<?x?x4x5xf32>
{
  %0 = linalg.generic {
    indexing_maps = [#map0, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%arg2 : tensor<?x?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %s: f32):
    %1 = arith.mulf %arg3, %arg4 : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, %sz1, 4, 5] :
    tensor<?x?xf32> into tensor<?x?x4x5xf32>
  return %1 : tensor<?x?x4x5xf32>
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
// CHECK: func @generic_op_reshape_consumer_fusion_reduction
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
// CHECK: %[[C20:.+]] = arith.constant 20 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C20]] : index
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_0]], 4, 5, %[[DIM_0]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], %[[VAL_1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[ARG0]], %[[T1]] : tensor<?x?xf32>, tensor<?x4x5x?xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<?x?x4x5xf32>)
// CHECK: return %[[T3]] : tensor<?x?x4x5xf32>

// -----

#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
func.func @generic_op_reshape_producer_fusion_with_reduction(%arg0 : tensor<?x7x?x8xf32>,
                                                             %arg1 : tensor<?x4x?xf32>,
                                                             %arg2 : tensor<?x?xf32>) ->
                                                             tensor<?x?xf32>
{
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
    tensor<?x7x?x8xf32> into tensor<?x?xf32>
  %1 = linalg.generic {
    indexing_maps = [#map0, #map1, #map2],
    iterator_types = ["parallel", "reduction", "parallel"]}
    ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x4x?xf32>)
    outs(%arg2 : tensor<?x?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
    %1 = arith.mulf %arg3, %arg4 : f32
    %2 = arith.addf %1, %arg5 : f32
    linalg.yield %2 : f32
  } -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d0, d1)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
// CHECK: func @generic_op_reshape_producer_fusion_with_reduction
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x4x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[C7:.+]] = arith.constant 7 : index
// CHECK: %[[C8:.+]] = arith.constant 8 : index
// CHECK: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x4x?xf32>
// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x4x?xf32>
// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [%[[VAL_0]], 8, 4, %[[VAL_1]], 7] : tensor<?x4x?xf32> into tensor<?x8x4x?x7xf32>
// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C8]] : index
// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C7]] : index
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "reduction", "parallel", "parallel"]
// CHECK-SAME: ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x8x4x?x7xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<?x8x?x7xf32>)
// CHECK: %[[T4:.+]] = tensor.collapse_shape %[[T3]]
// CHECK-SAME: [0, 1], [2, 3]
// CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32>
// CHECK: return %[[T4]]

// -----

func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
                                              %arg1 : tensor<?x?xf32>,
                                              %arg2 : tensor<?x?xf32>,
                                              %sz0: index,
                                              %sz1: index) ->
                                              tensor<?x?x4x5xf32>
{
  %0 = linalg.add ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, %sz1, 4, 5] :
    tensor<?x?xf32> into tensor<?x?x4x5xf32>
  return %1 : tensor<?x?x4x5xf32>
}

// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @linalg_add_reshape_consumer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
// CHECK: %[[C20:.+]] = arith.constant 20 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C20]] : index
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], %[[VAL_0]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], %[[VAL_1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
// CHECK: %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
// CHECK: %[[T4:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
// CHECK-SAME: outs(%[[T3]] : tensor<?x?x4x5xf32>)
// CHECK: return %[[T4]] : tensor<?x?x4x5xf32>

// -----

#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
                                              %arg1 : tensor<?x?xf32>,
                                              %arg2 : tensor<?x?xf32>) ->
                                              tensor<?x?xf32>
{
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
    tensor<?x7x?x8xf32> into tensor<?x?xf32>
  %1 = linalg.add ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @linalg_add_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK: %[[C8:.+]] = arith.constant 8 : index
// CHECK: %[[C7:.+]] = arith.constant 7 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C7]] : index
// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C8]] : index
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_0]], 7, %[[VAL_1]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C7]] : index
// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 7, %[[VAL_3]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x7x?x8xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<?x7x?x8xf32>)
// CHECK: %[[T4:.+]] = tensor.collapse_shape %[[T3]]
// CHECK-SAME: [0, 1], [2, 3]
// CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
// CHECK: return %[[T4]]

// -----

func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
  %cst = arith.constant 0 : i32
  %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
    tensor.yield %cst : i32
  } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
  return %padded_0 : tensor<8x12x17x336x14xi32>
}
// CHECK: func @fuse_by_expanding_pad(
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
// CHECK-SAME: low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
// CHECK: tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
// CHECK-SAME: : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
// CHECK: return %[[COLLAPSE]]

// -----

func.func @no_fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x339x14xi32> {
  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
  %cst = arith.constant 0 : i32
  %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
    tensor.yield %cst : i32
  } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
  return %padded_0 : tensor<8x12x17x339x14xi32>
}
// CHECK: func @no_fuse_by_expanding_pad(
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
// CHECK-SAME: : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]]
// CHECK-SAME: low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2]
// CHECK: tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
// CHECK: return %[[PAD]]

// -----

func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: index, %l1: index, %h0: index, %h1: index) -> tensor<?x?x?x?xi32> {
  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5]] : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
  %cst = arith.constant 0 : i32
  %padded_0 = tensor.pad %collapse low[%l0, 0, %l1, 0] high[%h0, 0, %h1, 0] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
    tensor.yield %cst : i32
  } : tensor<?x?x?x?xi32> to tensor<?x?x?x?xi32>
  return %padded_0 : tensor<?x?x?x?xi32>
}
// CHECK: func @fuse_by_expanding_dynamic_pad(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?x?x?xi32>
// CHECK-SAME: %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
// CHECK-SAME: low[%[[L0]], 0, 0, %[[L1]], 0, 0] high[%[[H0]], 0, 0, %[[H1]], 0, 0]
// CHECK: tensor<?x?x?x?x?x?xi32> to tensor<?x?x?x?x?x?xi32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
// CHECK-SAME: : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
// CHECK: return %[[COLLAPSE]]