// RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s

// CHECK-LABEL: expand_shape_identity_fold
// CHECK-NEXT: return
func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
  %0 = tensor.expand_shape %arg0 [[0]] output_shape [5] : tensor<5xf32> into tensor<5xf32>
  return %0 : tensor<5xf32>
}

// -----

// CHECK-LABEL: expand_shape_rank0_identity_fold
// CHECK-NEXT: return
func.func @expand_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
  %0 = tensor.expand_shape %arg0 [] output_shape [] : tensor<f32> into tensor<f32>
  return %0 : tensor<f32>
}

// -----

// CHECK-LABEL: collapse_shape_identity_fold
// CHECK-NEXT: return
func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf32> {
  %0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<5x4xf32> into tensor<5x4xf32>
  return %0 : tensor<5x4xf32>
}

// -----

// CHECK-LABEL: collapse_shape_rank0_identity_fold
// CHECK-NEXT: return
func.func @collapse_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
  %0 = tensor.collapse_shape %arg0 [] : tensor<f32> into tensor<f32>
  return %0 : tensor<f32>
}

// -----

// CHECK-LABEL: @tensor_bitcast_chain_ok
// CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {
  // CHECK-NEXT: %[[RES:.*]] = tensor.bitcast %[[IN]] : tensor<2xi32> to tensor<2xf32>
  %0 = tensor.bitcast %input : tensor<2xi32> to tensor<2xui32>
  %1 = tensor.bitcast %0 : tensor<2xui32> to tensor<2xf32>
  // CHECK-NEXT: return %[[RES]]
  return %1 : tensor<2xf32>
}

// -----

// CHECK-LABEL: @tensor_bitcast_chain_nop
// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
func.func @tensor_bitcast_chain_nop(%input: tensor<4xi32>) -> tensor<4xi32> {
  %0 = tensor.bitcast %input : tensor<4xi32> to tensor<4xui32>
  %1 = tensor.bitcast %0 : tensor<4xui32> to tensor<4xi32>
  // CHECK-NEXT: return %[[IN]]
  return %1 : tensor<4xi32>
}

// -----

// Checks that NOP casts are removed.
// CHECK-LABEL: cast_values
func.func @cast_values(%arg0: tensor<*xi32>) -> tensor<2xi32> {
  // NOP cast
  %0 = tensor.cast %arg0 : tensor<*xi32> to tensor<*xi32>
  // CHECK-NEXT: %[[RET:.*]] = tensor.cast %arg0 : tensor<*xi32> to tensor<2xi32>
  %2 = tensor.cast %0 : tensor<*xi32> to tensor<2xi32>
  // NOP cast
  %4 = tensor.cast %2 : tensor<2xi32> to tensor<2xi32>
  // CHECK-NEXT: return %[[RET]] : tensor<2xi32>
  return %4 : tensor<2xi32>
}

// -----

// CHECK-LABEL: @tensor.cast_chain_ok
// CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
func.func @tensor.cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
  // CHECK-NEXT: %[[RES:.*]] = tensor.cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
  %0 = tensor.cast %input : tensor<*xi32> to tensor<4x?xi32>
  %1 = tensor.cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
  // CHECK-NEXT: return %[[RES]]
  return %1 : tensor<4x8xi32>
}

// -----

// CHECK-LABEL: @tensor.cast_chain_regain
// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
func.func @tensor.cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
  %0 = tensor.cast %input : tensor<4xi32> to tensor<?xi32>
  %1 = tensor.cast %0 : tensor<?xi32> to tensor<4xi32>
  // CHECK-NEXT: return %[[IN]]
  return %1 : tensor<4xi32>
}

// -----

// CHECK-LABEL: @tensor.cast_chain_keep
// CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
func.func @tensor.cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
  // CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]]
  %0 = tensor.cast %input : tensor<?x?xi32> to tensor<4x?xi32>
  // CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]]
  %1 = tensor.cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
  // CHECK-NEXT: return %[[C2]]
  return %1 : tensor<?x8xi32>
}

// -----

// CHECK-LABEL: @tensor.cast_chain_invalid
// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
func.func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
  // CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]]
  %0 = tensor.cast %input : tensor<4x8xi32> to tensor<?x?xi32>
  // CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]]
  %1 = tensor.cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
  // CHECK-NEXT: return %[[C2]]
  return %1 : tensor<8x4xi32>
}

// -----

// CHECK-LABEL: fold_concat
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x2x?xi32>
func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1x2x?xi32>) {
  %0 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x3xi32>
  // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<1x2x?xi32> to tensor<1x2x3xi32>
  %1 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x?xi32>
  // CHECK-NEXT: return %[[CAST]], %[[ARG0]] : tensor<1x2x3xi32>, tensor<1x2x?xi32>
  return %0, %1 : tensor<1x2x3xi32>, tensor<1x2x?xi32>
}

// -----

// CHECK-LABEL: func @fold_extract
func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
  %const_0 = arith.constant 0 : index
  %const_1 = arith.constant 1 : index
  %const_3 = arith.constant 3 : index
  // CHECK-DAG: [[C64:%.+]] = arith.constant 64 : i32
  // CHECK-DAG: [[C0:%.+]] = arith.constant 0.{{0*}}e+00 : f16
  // CHECK-DAG: [[CM2:%.+]] = arith.constant -2.{{0*}}e+00 : f16

  // Fold an extract into a splat.
  // CHECK-DAG: [[C4:%.+]] = arith.constant 4.{{0*}}e+00 : f32
  %0 = arith.constant dense<4.0> : tensor<4xf32>
  %ext_1 = tensor.extract %0[%arg0] : tensor<4xf32>

  // Fold an extract into a sparse with a sparse index.
  %1 = arith.constant sparse<[[0, 0, 0], [1, 1, 1]], [-5.0, -2.0]> : tensor<4x4x4xf16>
  %ext_2 = tensor.extract %1[%const_1, %const_1, %const_1] : tensor<4x4x4xf16>

  // Fold an extract into a sparse with a non sparse index.
  %2 = arith.constant sparse<[[1, 1, 1]], [-2.0]> : tensor<2x2x2xf16>
  %ext_3 = tensor.extract %2[%const_0, %const_0, %const_0] : tensor<2x2x2xf16>

  // Fold an extract into a dense tensor.
  %3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
  %ext_4 = tensor.extract %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>

  // Fold an extract into a complex constant.
  // CHECK-DAG: [[C5:%.+]] = complex.constant [1.200000e+00 : f32, 2.300000e+00 : f32] : complex<f32>
  %4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
  %ext_5 = tensor.extract %4[] : tensor<complex<f32>>

  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
}

// -----

// Ensure extracting from a dense resource constant does not crash.

// CHECK-LABEL: func @extract_dense_resource_nofold
func.func @extract_dense_resource_nofold() -> i64 {
  // CHECK: %[[EXT:.+]] = tensor.extract
  // CHECK-NEXT: return %[[EXT]]
  %c0 = arith.constant 0 : index
  %cst = arith.constant dense_resource<__elided__> : tensor<1xi64>
  %extracted = tensor.extract %cst[%c0] : tensor<1xi64>
  return %extracted : i64
}

// -----

// CHECK-LABEL: func @fold_insert
func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
  // Fold an insert into a splat.
  // CHECK-DAG: %[[C4:.+]] = arith.constant dense<4.{{0*}}e+00> : tensor<4xf32>
  %0 = arith.constant dense<4.0> : tensor<4xf32>
  %1 = arith.constant 4.0 : f32
  %ins_1 = tensor.insert %1 into %0[%arg0] : tensor<4xf32>
  // CHECK-NEXT: return %[[C4]]
  return %ins_1 : tensor<4xf32>
}

// -----

// CHECK-LABEL: func @extract_from_tensor.cast
// CHECK-SAME: %[[TENSOR:.*]]: tensor<9xf32>
func.func @extract_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 {
  // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
  %c0 = arith.constant 0 : index
  // CHECK-NOT: tensor.cast
  %casted = tensor.cast %tensor : tensor<9xf32> to tensor<?xf32>
  // CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]]
  %result = tensor.extract %casted[%c0] : tensor<?xf32>
  return %result : f32
}

// -----

// CHECK-LABEL: func @extract_from_tensor.from_elements
func.func @extract_from_tensor.from_elements(%element : index) -> index {
  // CHECK-SAME: ([[ARG:%.*]]: index)
  %c0 = arith.constant 0 : index
  %tensor = tensor.from_elements %element : tensor<1xindex>
  %extracted_element = tensor.extract %tensor[%c0] : tensor<1xindex>
  // CHECK: [[ARG]] : index
  return %extracted_element : index
}

// -----

// CHECK-LABEL: func @extract_from_tensor.from_elements_0d
func.func @extract_from_tensor.from_elements_0d(%element : index) -> index {
  // CHECK-SAME: ([[ARG:%.*]]: index)
  %c0 = arith.constant 0 : index
  %tensor = tensor.from_elements %element : tensor<index>
  %extracted_element = tensor.extract %tensor[] : tensor<index>
  // CHECK: [[ARG]] : index
  return %extracted_element : index
}

// -----

// CHECK-LABEL: func @extract_from_tensor.from_elements_3d
func.func @extract_from_tensor.from_elements_3d()
    -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
  %f0 = arith.constant 0.0 : f32
  %f1 = arith.constant 1.0 : f32
  %f2 = arith.constant 2.0 : f32
  %f3 = arith.constant 3.0 : f32
  %f4 = arith.constant 4.0 : f32
  %f5 = arith.constant 5.0 : f32
  %f6 = arith.constant 6.0 : f32
  %f7 = arith.constant 7.0 : f32
  %f8 = arith.constant 8.0 : f32
  %f9 = arith.constant 9.0 : f32
  %f10 = arith.constant 10.0 : f32
  %f11 = arith.constant 11.0 : f32

  %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
    : tensor<3x2x2xf32>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index

  %r0 = tensor.extract %tensor[%c0, %c0, %c0] : tensor<3x2x2xf32>
  %r1 = tensor.extract %tensor[%c0, %c0, %c1] : tensor<3x2x2xf32>
  %r2 = tensor.extract %tensor[%c0, %c1, %c0] : tensor<3x2x2xf32>
  %r3 = tensor.extract %tensor[%c0, %c1, %c1] : tensor<3x2x2xf32>
  %r4 = tensor.extract %tensor[%c1, %c0, %c0] : tensor<3x2x2xf32>
  %r5 = tensor.extract %tensor[%c1, %c0, %c1] : tensor<3x2x2xf32>
  %r6 = tensor.extract %tensor[%c1, %c1, %c0] : tensor<3x2x2xf32>
  %r7 = tensor.extract %tensor[%c1, %c1, %c1] : tensor<3x2x2xf32>
  %r8 = tensor.extract %tensor[%c2, %c0, %c0] : tensor<3x2x2xf32>
  %r9 = tensor.extract %tensor[%c2, %c0, %c1] : tensor<3x2x2xf32>
  %r10 = tensor.extract %tensor[%c2, %c1, %c0] : tensor<3x2x2xf32>
  %r11 = tensor.extract %tensor[%c2, %c1, %c1] : tensor<3x2x2xf32>
  return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
    : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
}
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0
// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0
// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0
// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0
// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0
// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0
// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01

// CHECK: return %[[F0]], %[[F1]], %[[F2]], %[[F3]], %[[F4]], %[[F5]],
// CHECK-SAME: %[[F6]], %[[F7]], %[[F8]], %[[F9]], %[[F10]], %[[F11]]

// -----

// CHECK-LABEL: func @extract_from_tensor.from_elements_variable_3d
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_2:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_3:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_4:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_5:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_6:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_7:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_8:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_9:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_10:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_11:[a-zA-Z0-9_]+]]: f32
func.func @extract_from_tensor.from_elements_variable_3d(
    %f0: f32, %f1: f32, %f2: f32, %f3: f32, %f4: f32, %f5: f32,
    %f6: f32, %f7: f32, %f8: f32, %f9: f32, %f10: f32, %f11: f32)
    -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {

  %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
    : tensor<3x2x2xf32>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index

  %r0 = tensor.extract %tensor[%c0, %c0, %c0] : tensor<3x2x2xf32>
  %r1 = tensor.extract %tensor[%c0, %c0, %c1] : tensor<3x2x2xf32>
  %r2 = tensor.extract %tensor[%c0, %c1, %c0] : tensor<3x2x2xf32>
  %r3 = tensor.extract %tensor[%c0, %c1, %c1] : tensor<3x2x2xf32>
  %r4 = tensor.extract %tensor[%c1, %c0, %c0] : tensor<3x2x2xf32>
  %r5 = tensor.extract %tensor[%c1, %c0, %c1] : tensor<3x2x2xf32>
  %r6 = tensor.extract %tensor[%c1, %c1, %c0] : tensor<3x2x2xf32>
  %r7 = tensor.extract %tensor[%c1, %c1, %c1] : tensor<3x2x2xf32>
  %r8 = tensor.extract %tensor[%c2, %c0, %c0] : tensor<3x2x2xf32>
  %r9 = tensor.extract %tensor[%c2, %c0, %c1] : tensor<3x2x2xf32>
  %r10 = tensor.extract %tensor[%c2, %c1, %c0] : tensor<3x2x2xf32>
  %r11 = tensor.extract %tensor[%c2, %c1, %c1] : tensor<3x2x2xf32>
  return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
    : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
}
// CHECK: return %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]],
// CHECK-SAME: %[[ARG_6]], %[[ARG_7]], %[[ARG_8]], %[[ARG_9]], %[[ARG_10]], %[[ARG_11]]

// -----

// CHECK-LABEL: func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
// CHECK-NEXT: %cst = arith.constant dense<[(1,2), (3,2), (1,2)]> : tensor<3xcomplex<i32>>
// CHECK-NEXT: return %cst : tensor<3xcomplex<i32>>
func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
  %c1 = arith.constant dense<(1, 2)> : tensor<complex<i32>>
  %complex1 = tensor.extract %c1[] : tensor<complex<i32>>
  %c2 = arith.constant dense<(3, 2)> : tensor<complex<i32>>
  %complex2 = tensor.extract %c2[] : tensor<complex<i32>>
  %tensor = tensor.from_elements %complex1, %complex2, %complex1 : tensor<3xcomplex<i32>>
  return %tensor : tensor<3xcomplex<i32>>
}

// -----

// CHECK-LABEL: func.func @extract_from_elements_complex_f() -> tensor<3xcomplex<f32>> {
// CHECK-NEXT: %cst = arith.constant dense<[(1.200000e+00,2.300000e+00), (3.200000e+00,2.100000e+00), (1.200000e+00,2.300000e+00)]> : tensor<3xcomplex<f32>>
// CHECK-NEXT: return %cst : tensor<3xcomplex<f32>>
func.func @extract_from_elements_complex_f() -> tensor<3xcomplex<f32>> {
  %c1 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
  %complex1 = tensor.extract %c1[] : tensor<complex<f32>>
  %c2 = arith.constant dense<(3.2, 2.1)> : tensor<complex<f32>>
  %complex2 = tensor.extract %c2[] : tensor<complex<f32>>
  %tensor = tensor.from_elements %complex1, %complex2, %complex1 : tensor<3xcomplex<f32>>
  return %tensor : tensor<3xcomplex<f32>>
}

// -----

// Ensure the optimization doesn't segfault from bad constants
// CHECK-LABEL: func @extract_negative_from_tensor.from_elements
func.func @extract_negative_from_tensor.from_elements(%element : index) -> index {
  // CHECK-SAME: ([[ARG:%.*]]: index)
  %c-1 = arith.constant -1 : index
  %tensor = tensor.from_elements %element : tensor<1xindex>
  %extracted_element = tensor.extract %tensor[%c-1] : tensor<1xindex>
  // CHECK: tensor.from_elements
  // CHECK: %[[RESULT:.*]] = tensor.extract
  // CHECK: return %[[RESULT]]
  return %extracted_element : index
}

// -----

// Ensure the optimization doesn't segfault from bad constants
// CHECK-LABEL: func @extract_oob_from_tensor.from_elements
func.func @extract_oob_from_tensor.from_elements(%element : index) -> index {
  // CHECK-SAME: ([[ARG:%.*]]: index)
  %c1 = arith.constant 1 : index
  %tensor = tensor.from_elements %element : tensor<1xindex>
  %extracted_element = tensor.extract %tensor[%c1] : tensor<1xindex>
  // CHECK: tensor.from_elements
  // CHECK: %[[RESULT:.*]] = tensor.extract
  // CHECK: return %[[RESULT]]
  return %extracted_element : index
}

// -----

// Ensure the optimization doesn't segfault from bad constants
// CHECK-LABEL: func @extract_oob_from_tensor.from_elements
func.func @extract_oob_from_tensor.from_elements(%element : index) -> index {
  // CHECK-SAME: ([[ARG:%.*]]: index)
  %c2 = arith.constant 2 : index
  %tensor = tensor.from_elements %element : tensor<1xindex>
  %extracted_element = tensor.extract %tensor[%c2] : tensor<1xindex>
  // CHECK: tensor.from_elements
  // CHECK: %[[RESULT:.*]] = tensor.extract
  // CHECK: return %[[RESULT]]
  return %extracted_element : index
}

// -----

// CHECK-LABEL: func @extract_from_tensor.generate
// CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
func.func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index {
  %size = tensor.rank %tensor : tensor<*xf32>
  // CHECK-NEXT: %[[RES:.*]] = tensor.dim %[[TENSOR]], %[[IDX]]
  %0 = tensor.generate %size {
    ^bb0(%arg0: index):
    %1 = tensor.dim %tensor, %arg0 : tensor<*xf32>
    tensor.yield %1 : index
  } : tensor<?xindex>
  %1 = tensor.extract %0[%idx] : tensor<?xindex>
  // CHECK-NEXT: return %[[RES]]
  return %1 : index
}

// -----

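// Test case: tensor.extract of a side-effect-free tensor.generate inlines the body
// computation at the extraction point.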
// CHECK-LABEL: func @extract_from_tensor.generate_2d
// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
func.func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index {
  %size = tensor.rank %tensor : tensor<*xf32>
  // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[TENSOR]], %[[IDX0]]
  // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[TENSOR]], %[[IDX1]]
  // CHECK-NEXT: %[[RES:.*]] = arith.addi %[[DIM0]], %[[DIM1]]
  %0 = tensor.generate %size, %size {
    ^bb0(%arg0: index, %arg1: index):
    %1 = tensor.dim %tensor, %arg0 : tensor<*xf32>
    %2 = tensor.dim %tensor, %arg1 : tensor<*xf32>
    %3 = arith.addi %1, %2 : index
    tensor.yield %3 : index
  } : tensor<?x?xindex>
  %4 = tensor.extract %0[%idx0, %idx1] : tensor<?x?xindex>
  // CHECK-NEXT: return %[[RES]]
  return %4 : index
}

// -----

// CHECK-LABEL: func @extract_from_tensor.generate_sideeffects
// CHECK-SAME: %[[IDX:.*]]: index
func.func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>, %mem: memref<?xindex>) -> index {
  %size = tensor.rank %tensor : tensor<*xf32>
  // CHECK: %[[DTENSOR:.*]] = tensor.generate
  %0 = tensor.generate %size {
    ^bb0(%arg0: index):
    %1 = tensor.dim %tensor, %arg0 : tensor<*xf32>
    memref.store %1, %mem[%arg0] : memref<?xindex>
    tensor.yield %1 : index
  } : tensor<?xindex>
  // CHECK: %[[RES:.*]] = tensor.extract %[[DTENSOR]][%[[IDX]]]
  %1 = tensor.extract %0[%idx] : tensor<?xindex>
  // CHECK-NEXT: return %[[RES]]
  return %1 : index
}

// -----

// CHECK-LABEL: @static_tensor.generate
// CHECK-SAME: %[[SIZE1:.*]]: index, %[[SIZE4:.*]]: index)
func.func @static_tensor.generate(%size1: index, %size4: index) -> tensor<3x?x?x7x?xindex> {
  %c5 = arith.constant 5 : index
  // CHECK: tensor.generate %[[SIZE1]], %[[SIZE4]]
  %0 = tensor.generate %size1, %c5, %size4 {
    ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
    %1 = arith.constant 32 : index
    tensor.yield %1 : index
  // CHECK: : tensor<3x?x5x7x?xindex>
  } : tensor<3x?x?x7x?xindex>
  // CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
  return %0 : tensor<3x?x?x7x?xindex>
}

// -----

// CHECK-LABEL: @from_elements.constant
func.func @from_elements.constant() -> tensor<3xindex> {
  // CHECK: %[[CST:.*]] = arith.constant dense<[1, 2, 1]> : tensor<3xindex>
  // CHECK: return %[[CST]]
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %tensor = tensor.from_elements %c1, %c2, %c1 : tensor<3xindex>
  return %tensor : tensor<3xindex>
}

// -----

func.func @slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
    %arg2 : index) -> tensor<?x?x?xf32>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @slice_canonicalize
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
// CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x1x?xf32>
// CHECK: %[[RESULT:.+]] = tensor.cast %[[SLICE]]
// CHECK: return %[[RESULT]]

// -----

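// Test case: constant offsets, sizes, and strides fold into the rank-reducing
// extract_slice; a cast restores the original result type.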
func.func @rank_reducing_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
    %arg2 : index) -> tensor<?x?xf32>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: func @rank_reducing_slice_canonicalize
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
// CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x?xf32>
// CHECK: %[[RESULT:.+]] = tensor.cast %[[SLICE]]
// CHECK: return %[[RESULT]]

// -----

// CHECK-LABEL: func @trivial_slice
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
// CHECK-NOT: tensor.extract_slice
// CHECK: return %[[ARG0]] : tensor<4x6x16x32xi8>
func.func @trivial_slice(%arg0 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
  %0 = tensor.extract_slice %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<4x6x16x32xi8>
  return %0 : tensor<4x6x16x32xi8>
}

// -----

// CHECK-LABEL: func @trivial_insert_slice
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
// CHECK-NOT: tensor.extract_slice
// CHECK: return %[[ARG0]] : tensor<4x6x16x32xi8>
func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
  %0 = tensor.insert_slice %arg0 into %arg1[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> into tensor<4x6x16x32xi8>
  return %0 : tensor<4x6x16x32xi8>
}

// -----

// CHECK-LABEL: func @empty_insert_slice
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8>
// CHECK-SAME: %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>
// CHECK-NOT: tensor.extract_slice
// CHECK: return %[[ARG1]] : tensor<3x3xi8>
func.func @empty_insert_slice(%arg0 : tensor<0x2xi8>, %arg1 : tensor<3x3xi8>) -> tensor<3x3xi8> {
  %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [0, 2] [1, 1] : tensor<0x2xi8> into tensor<3x3xi8>
  return %0 : tensor<3x3xi8>
}

// -----

// CHECK-LABEL: func @rank_reducing_tensor_of_cast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
// CHECK: %[[S:.+]] = tensor.extract_slice %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
// Tensor cast is moved after slice and then gets canonicalized away.
// CHECK-NOT: tensor.cast
// CHECK: return %[[S]] : tensor<16x32xi8>
func.func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32xi8> {
  %0 = tensor.cast %arg : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
  %1 = tensor.extract_slice %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<?x?x16x32xi8> to tensor<16x32xi8>
  return %1 : tensor<16x32xi8>
}

// -----

// CHECK-LABEL: func @rank_reducing_insert_slice_of_cast
// CHECK-SAME: %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
// CHECK-SAME: %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
// CHECK: %[[S:.+]] = tensor.insert_slice %[[A]] into %[[B]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<16x32xi8> into tensor<4x6x16x32xi8>
// Tensor cast is folded away.
// CHECK-NOT: tensor.cast
// CHECK: return %[[S]] : tensor<4x6x16x32xi8>
func.func @rank_reducing_insert_slice_of_cast(%a : tensor<16x32xi8>, %b : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
  %c0 = arith.constant 0: index
  %cast = tensor.cast %a : tensor<16x32xi8> to tensor<?x32xi8>
  %sz = tensor.dim %cast, %c0: tensor<?x32xi8>
  %res = tensor.insert_slice %cast into %b[0, 1, 0, 0] [1, 1, %sz, 32] [1, 1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
  return %res : tensor<4x6x16x32xi8>
}

// -----

func.func @insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = tensor.insert_slice %arg0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @insert_slice_canonicalize
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK: %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<4x1x?xf32>
// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
// CHECK-SAME: : tensor<4x1x?xf32> into tensor<?x?x?xf32>
// CHECK: return %[[RESULT]]

// -----

// Do not insert a cast for the following example. The new source type wouldn't be "more static" than the old one.
func.func @insert_slice_canonicalize_encoding(%arg0 : tensor<2x2xf32, "foo">,
    %arg1 : tensor<4x4xf32, "foo">) -> tensor<4x4xf32, "foo">
{
  %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [2, 2] [1, 1] : tensor<2x2xf32, "foo"> into tensor<4x4xf32, "foo">
  return %0 : tensor<4x4xf32, "foo">
}
// CHECK-LABEL: func @insert_slice_canonicalize_encoding
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x2xf32, "foo">
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<4x4xf32, "foo">
// CHECK-NOT: tensor.cast
// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]] into %[[ARG1]]
// CHECK-SAME: [0, 0] [2, 2] [1, 1]
// CHECK-SAME: : tensor<2x2xf32, "foo"> into tensor<4x4xf32, "foo">
// CHECK: return %[[RESULT]]

// -----

func.func @slice_to_insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
  %1 = tensor.insert_slice %0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @slice_to_insert_slice_canonicalize
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
// CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x1x?xf32>
// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[SLICE]]
// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
// CHECK-SAME: : tensor<4x1x?xf32> into tensor<?x?x?xf32>
// CHECK: return %[[RESULT]]

// -----

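// Test case: the constant size lets the source of the rank-reducing insert_slice
// be cast to a more static type.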
func.func @rank_reducing_insert_slice_canonicalize(%arg0 : tensor<?x?xf32>, %arg1 : index,
    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = tensor.insert_slice %arg0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @rank_reducing_insert_slice_canonicalize
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
// CHECK: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?xf32> to tensor<4x?xf32>
// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
// CHECK-SAME: : tensor<4x?xf32> into tensor<?x?x?xf32>
// CHECK: return %[[RESULT]]

// -----

func.func @rank_reducing_slice_to_insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
  %1 = tensor.insert_slice %0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @rank_reducing_slice_to_insert_slice_canonicalize
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
// CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x?xf32>
// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ARG3]]
// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
// CHECK-SAME: : tensor<4x?xf32> into tensor<?x?x?xf32>
// CHECK: return %[[RESULT]]

// -----

func.func @insert_slice_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
    %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c8 = arith.constant 8 : index
  %0 = tensor.dim %arg0, %c1 : tensor<2x?xi32>
  %1 = tensor.extract %arg1[] : tensor<i32>
  %2 = tensor.generate %arg2, %c8 {
    ^bb0(%arg4: index, %arg5: index):
    tensor.yield %1 : i32
  } : tensor<?x?xi32>
  %3 = tensor.insert_slice %arg0 into %2[0, %arg3] [2, %0] [1, 1] : tensor<2x?xi32> into tensor<?x?xi32>
  return %3 : tensor<?x?xi32>
}
// CHECK-LABEL: func @insert_slice_propagate_dest_cast
// CHECK: %[[UPDATED:.+]] = tensor.insert_slice %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
// CHECK-SAME: tensor<2x?xi32> into tensor<?x8xi32>
// CHECK: %[[CAST:.+]] = tensor.cast %[[UPDATED]]
// CHECK: return %[[CAST]]

// -----

func.func @insert_slice_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor<i32>) -> tensor<3x9xi32> {
  %c9 = arith.constant 9 : index
  %c3 = arith.constant 3 : index
  %2 = tensor.extract %arg1[] : tensor<i32>
  %4 = tensor.generate %c3, %c9 {
    ^bb0(%arg2: index, %arg3: index):
    tensor.yield %2 : i32
  } : tensor<?x?xi32>
  %5 = tensor.insert_slice %arg0 into %4[0, 1] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<?x?xi32>
  %6 = tensor.cast %5 : tensor<?x?xi32> to tensor<3x9xi32>
  return %6 : tensor<3x9xi32>
}
// CHECK-LABEL: func @insert_slice_output_dest_canonicalize
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3xi32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<i32>
// CHECK: %[[PAD:.+]] = tensor.extract %[[ARG1]]
// CHECK: %[[GENERATE:.+]] = tensor.generate
// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]] into %[[GENERATE]]
// CHECK: return %[[RESULT]]

// -----

// Test case: Folding of tensor.dim(tensor.generate %idx) -> %idx
// CHECK-LABEL: func @dim_of_tensor.generate(
// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
// CHECK-NOT: tensor.dim
// CHECK: return %[[IDX1]] : index
func.func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index {
  %c3 = arith.constant 3 : index
  %0 = tensor.generate %arg0, %arg1 {
    ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
    tensor.yield %c3 : index
  } : tensor<2x?x4x?x5xindex>
  %1 = tensor.dim %0, %c3 : tensor<2x?x4x?x5xindex>
  return %1 : index
}

// -----

// Test case: Folding tensor.dim(tensor.cast %0, %idx) -> tensor.dim %0, %idx
// CHECK-LABEL: func @fold_dim_of_tensor.cast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
// CHECK: %[[T0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-NEXT: return %[[C4]], %[[T0]]
func.func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
  %2 = tensor.dim %0, %c1 : tensor<?x?xf32>
  return %1, %2: index, index
}

// -----

// CHECK-LABEL: func @insert_slice_cast
func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
  // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>
  %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[ARG0]]
  // CHECK-SAME: [{{.*}}, {{.*}}] [1, {{.*}}] [{{.*}}, {{.*}}]
  // CHECK-SAME: : tensor<1x?xf32> into tensor<?x?xf32>
  %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
  // CHECK: return %[[RES]] : tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @insert_slice_cast_no_fold
func.func @insert_slice_cast_no_fold(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
  %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x5xf32>
  // CHECK: %[[CAST:.*]] = tensor.cast
  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
  // CHECK-SAME: [{{.*}}, {{.*}}] [{{.*}}, 5] [{{.*}}, {{.*}}]
  // CHECK-SAME: : tensor<?x5xf32> into tensor<?x?xf32>
  %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, 5] [%arg6, %arg7] : tensor<?x5xf32> into tensor<?x?xf32>
  // CHECK: return %[[RES]] : tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

// -----

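// Test case: constant sizes on the insert_slice allow casting the dynamic source
// to a fully static type.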
// CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src(
// CHECK-SAME: %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
// CHECK: %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x64xf32>
// CHECK: %[[r:.*]] = tensor.insert_slice %[[cast]] into %[[arg1]][0, 1, 2] [64, 5, 64] [1, 1, 1] : tensor<64x5x64xf32> into tensor<?x?x?xf32>
// CHECK: return %[[r]]
func.func @insert_tensor_cast_on_insert_slice_src(
    %arg0 : tensor<?x5x?xf32>, %arg1 : tensor<?x?x?xf32>, %sz0: index, %sz2: index) -> tensor<?x?x?xf32> {
  %c64 = arith.constant 64: index
  %r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [%c64, 5, %c64] [1, 1, 1]
    : tensor<?x5x?xf32> into tensor<?x?x?xf32>
  return %r : tensor<?x?x?xf32>
}

// -----

// CHECK-LABEL: func @fold_extract_insert
// CHECK-SAME: %{{.+}}: tensor<?x?x?xf32>, %[[SLICE:.+]]: tensor<4x?x8xf32>
func.func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<4x?x8xf32>) {
  %c0 = arith.constant 0: index
  %c1 = arith.constant 1: index
  %0 = tensor.insert_slice %slice into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
  %1 = tensor.extract_slice %0[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<?x?x?xf32> to tensor<4x?x8xf32>
  // CHECK: return %[[SLICE]]
  return %1 : tensor<4x?x8xf32>
}

// -----

// CHECK-LABEL: func @fold_gather_constant_splat
// CHECK-NOT: tensor.gather
// CHECK: arith.constant dense<1.000000e-01> : tensor<1x2x1x1x1xf32>
func.func @fold_gather_constant_splat(%indices : tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> {
  %cst = arith.constant dense<1.000000e-01> : tensor<4x4x4xf32>
  %0 = tensor.gather %cst[%indices] gather_dims([0, 1, 2]) :
    (tensor<4x4x4xf32>, tensor<1x2x 3xindex>) -> tensor<1x2x 1x1x1xf32>
  return %0 : tensor<1x2x 1x1x1xf32>
}

// -----

// CHECK-LABEL: func @fold_reshape_constant_splat
// CHECK-NOT: tensor.reshape
// CHECK: arith.constant dense<1.000000e-01> : tensor<4xf32>
func.func @fold_reshape_constant_splat(%shape : tensor<1xi32>) -> tensor<4xf32> {
  %cst = arith.constant dense<1.000000e-01> : tensor<4x1xf32>
  %0 = tensor.reshape %cst(%shape)
    : (tensor<4x1xf32>, tensor<1xi32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

// -----

// CHECK-LABEL: func @fold_reshape_chain
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<*xf32>
// CHECK-SAME: %[[SHAPE_0:[a-zA-Z0-9_]+]]: tensor<?xindex>
// CHECK-SAME: %[[SHAPE_1:[a-zA-Z0-9_]+]]: tensor<?xindex>
// CHECK-SAME: %[[SHAPE_2:[a-zA-Z0-9_]+]]: tensor<?xindex>
// CHECK: %[[RESULT:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE_2]])
// CHECK: return %[[RESULT]]
func.func @fold_reshape_chain(%input: tensor<*xf32>, %shape_0: tensor<?xindex>, %shape_1: tensor<?xindex>, %shape_2: tensor<?xindex>) -> tensor<*xf32> {
  %0 = tensor.reshape %input(%shape_0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  %1 = tensor.reshape %0(%shape_1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  %2 = tensor.reshape %1(%shape_2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  return %2 : tensor<*xf32>
}

// -----

// CHECK-LABEL: func @fold_reshape_1d
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<?xf32>
// CHECK-SAME: %[[SHAPE:[a-zA-Z0-9_]+]]: tensor<1xindex>
// CHECK: return %[[INPUT]]
func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> tensor<?xf32> {
  %0 = tensor.reshape %input(%shape) : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// -----

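// Test case: an extract_slice of a splat constant folds to a splat constant of the
// result shape.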
// CHECK-LABEL: func @fold_extract_constant_splat
// CHECK-NOT: tensor.extract_slice
// CHECK: arith.constant dense<42> : tensor<4x4xi32>
func.func @fold_extract_constant_splat() -> (tensor<4x4xi32>) {
  %cst = arith.constant dense<42> : tensor<1024x1024xi32>
  %1 = tensor.extract_slice %cst[0,0] [4,4] [1, 1] : tensor<1024x1024xi32> to tensor<4x4xi32>
  return %1 : tensor<4x4xi32>
}

// -----

// CHECK-LABEL: func @fold_pack_constant_splat
// CHECK-NOT: tensor.pack
// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
  %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
  %0 = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
    inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
  return %0 : tensor<8x16x8x32xf32>
}

// -----

// CHECK-LABEL: func @fold_padding_value_pack_constant_splat
// CHECK-NOT: tensor.pack
// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
  %pad = arith.constant 1.000000e-01 : f32
  %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
  %0 = tensor.pack %cst
    padding_value(%pad : f32)
    outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
    inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
  return %0 : tensor<8x16x8x32xf32>
}

// -----

// CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
// CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
// CHECK: tensor.pack
func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
  %pad = arith.constant 0.0 : f32
  %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
  %0 = tensor.pack %cst
    padding_value(%pad : f32)
    outer_dims_perm = [1, 0]
    inner_dims_pos = [0, 1]
    inner_tiles = [8, 32]
    into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
  return %0 : tensor<8x16x8x32xf32>
}

// -----

func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<31250x1200x16x1xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
  %pack = tensor.pack %arg0
    padding_value(%cst : f32)
    outer_dims_perm = [1, 0]
    inner_dims_pos = [1, 0]
    inner_tiles = [16, 1]
    into %0 : tensor<1200x500000xf32> -> tensor<31250x1200x16x1xf32>
  return %pack : tensor<31250x1200x16x1xf32>
}
// CHECK-LABEL: func @fold_padding_value_pack
// CHECK-NOT: padding_value

// -----

func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x30x40x16xf32>) -> tensor<10x20x30x40x16xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %pack = tensor.pack %src
    padding_value(%cst : f32)
    outer_dims_perm = [2, 1, 3, 0]
    inner_dims_pos = [2]
    inner_tiles = [16]
    into %dest : tensor<?x?x?x?xf32> -> tensor<10x20x30x40x16xf32>
  return %pack : tensor<10x20x30x40x16xf32>
}
// CHECK-LABEL: func.func @infer_src_shape_pack
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
// CHECK: return %[[PACK]]

// -----

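// Test case: the static source shape lets the pack's dynamic destination be cast to
// a more static type; the result is cast back to the original type.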
func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?x?x?x16xf32>) -> tensor<?x?x?x?x16xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %pack = tensor.pack %src
    padding_value(%cst : f32)
    outer_dims_perm = [2, 1, 3, 0]
    inner_dims_pos = [2]
    inner_tiles = [16]
    into %dest : tensor<30x20x?x10xf32> -> tensor<?x?x?x?x16xf32>
  return %pack : tensor<?x?x?x?x16xf32>
}
// CHECK-LABEL: func.func @infer_dest_shape_pack
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
// CHECK: %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<?x20x10x30x16xf32> to tensor<?x?x?x?x16xf32>
// CHECK: return %[[CAST_PACK]]

// -----

func.func @no_infer_pack_shape(%arg0: tensor<?x32x100xf32>, %arg1: index) -> tensor<32x7x?x16x1xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty(%arg1) : tensor<32x7x?x16x1xf32>
  %pack = tensor.pack %arg0 padding_value(%cst : f32) outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<?x32x100xf32> -> tensor<32x7x?x16x1xf32>
  return %pack : tensor<32x7x?x16x1xf32>
}
// CHECK-LABEL: func.func @no_infer_pack_shape
// CHECK-NOT: tensor.cast

// -----

func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
  %pack = tensor.pack %arg0
    padding_value(%cst : f32)
    outer_dims_perm = [1, 0]
    inner_dims_pos = [1, 0]
    inner_tiles = [16, 1]
    into %0 : tensor<1200x499999xf32> -> tensor<31250x1200x16x1xf32>
  return %pack : tensor<31250x1200x16x1xf32>
}
// CHECK-LABEL: func @fold_padding_value_pack_negative1
// CHECK: tensor.pack
// CHECK-SAME: padding_value

// -----

func.func @fold_padding_value_pack_negative2(%arg0: tensor<1200x?xf32>, %arg1: tensor<?x1200x16x1xf32>) -> tensor<?x1200x16x1xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %pack = tensor.pack %arg0
    padding_value(%cst : f32)
    outer_dims_perm = [1, 0]
    inner_dims_pos = [1, 0]
    inner_tiles = [16, 1]
    into %arg1 : tensor<1200x?xf32> -> tensor<?x1200x16x1xf32>
  return %pack : tensor<?x1200x16x1xf32>
}
// CHECK-LABEL: func @fold_padding_value_pack_negative2
// CHECK: tensor.pack
// CHECK-SAME: padding_value

// -----

func.func @fold_padding_value_pack_negative3(%arg0: tensor<1200x500000xf32>, %arg1: tensor<?x1200x?x1xf32>, %tile : index) -> tensor<?x1200x?x1xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %pack = tensor.pack %arg0
    padding_value(%cst : f32)
    outer_dims_perm = [1, 0]
    inner_dims_pos = [1, 0]
    inner_tiles = [%tile, 1]
    into %arg1 : tensor<1200x500000xf32> -> tensor<?x1200x?x1xf32>
  return %pack : tensor<?x1200x?x1xf32>
}
// CHECK-LABEL: func @fold_padding_value_pack_negative3
// CHECK: tensor.pack
// CHECK-SAME: padding_value

// -----

// CHECK-LABEL: func @fold_unpack_constant_splat
// CHECK-NOT: tensor.unpack
// CHECK: arith.constant dense<1.000000e-01> : tensor<128x256xf32>
func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128x256xf32> {
  %cst = arith.constant dense<1.000000e-01> : tensor<16x8x8x32xf32>
  %0 = tensor.unpack %cst inner_dims_pos = [0, 1]
    inner_tiles = [8, 32] into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
  return %0 : tensor<128x256xf32>
}

// -----

func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  %unpack = tensor.unpack %src
    outer_dims_perm = [2, 1, 3, 0]
    inner_dims_pos = [2]
    inner_tiles = [16]
    into %dest : tensor<10x20x30x40x16xf32> -> tensor<?x?x?x?xf32>
  return %unpack : tensor<?x?x?x?xf32>
}
// CHECK-LABEL: func.func @infer_dest_shape_unpack
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
// CHECK: %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<40x20x?x30xf32> to tensor<?x?x?x?xf32>
// CHECK: return %[[CAST_UNPACK]]

// -----

func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30x20x?x10xf32>) -> tensor<30x20x?x10xf32> {
  %unpack = tensor.unpack %src
    outer_dims_perm = [2, 1, 3, 0]
    inner_dims_pos = [2]
    inner_tiles = [16]
    into %dest : tensor<?x?x?x?x16xf32> -> tensor<30x20x?x10xf32>
  return %unpack : tensor<30x20x?x10xf32>
}
// CHECK-LABEL: func.func @infer_src_shape_unpack
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
// CHECK: return %[[UNPACK]]

// -----

func.func @no_infer_unpack_shape(%arg1: tensor<32x7x?x16x1xf32>, %arg2: index) -> tensor<?x32x100xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty(%arg2) : tensor<?x32x100xf32>
  %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<32x7x?x16x1xf32> -> tensor<?x32x100xf32>
  return %unpack : tensor<?x32x100xf32>
}
// CHECK-LABEL: func.func @no_infer_unpack_shape
// CHECK-NOT: tensor.cast

// -----

// CHECK-LABEL: func @fold_overlapping_insert
// CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
  %c0 = arith.constant 0: index
  %c1 = arith.constant 1: index
  %0 = tensor.insert_slice %slice1 into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
  // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE2]] into %[[INPUT]]
  %1 = tensor.insert_slice %slice2 into %0[0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
  // CHECK: return %[[INSERT]]
  return %1 : tensor<?x?x?xf32>
}

// -----

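// Test case: two consecutive expand_shapes compose into a single expand_shape.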
func.func @compose_expand_of_expand(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
    -> tensor<?x6x4x?x5xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
    : tensor<?x?xf32> into tensor<?x4x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4]] output_shape [%arg3, 6, 4, %arg4, 5] : tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
  return %1 : tensor<?x6x4x?x5xf32>
}
// CHECK-LABEL: compose_expand_of_expand
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] output_shape [%arg3, 6, 4, %arg4, 5]
// CHECK-NOT: tensor.expand_shape

// -----

func.func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor<f32>)
    -> tensor<1x1x1xf32> {
  %0 = tensor.expand_shape %arg0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
  %1 = tensor.expand_shape %0 [[0, 1, 2]] output_shape [1, 1, 1]
    : tensor<1xf32> into tensor<1x1x1xf32>
  return %1 : tensor<1x1x1xf32>
}
// CHECK-LABEL: compose_expand_of_expand_of_zero_dim
// CHECK: tensor.expand_shape %{{.*}} [] output_shape [1, 1, 1]
// CHECK-SAME: tensor<f32> into tensor<1x1x1xf32>

// -----

// CHECK-LABEL: func.func @collapse_of_cast(
// CHECK-SAME: %[[IN:.*]]: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
// CHECK-NEXT: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[IN]] {{\[}}[0, 1], [2]] : tensor<8x12x32xf32> into tensor<96x32xf32>
// CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[COLLAPSE]] : tensor<96x32xf32> to tensor<?x32xf32>
// CHECK-NEXT: return %[[CAST]] : tensor<?x32xf32>
func.func @collapse_of_cast(%t: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
  %0 = tensor.cast %t : tensor<8x12x32xf32> to tensor<?x?x?xf32>
  %1 = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
  %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x32xf32>
  return %2 : tensor<?x32xf32>
}

// -----

func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 4, 4]
    : tensor<12x4xf32> into tensor<3x4x4xf32>
  %1 = tensor.collapse_shape %0 [[0, 1], [2]]
    : tensor<3x4x4xf32> into tensor<12x4xf32>
  return %1 : tensor<12x4xf32>
}
// CHECK-LABEL: @fold_collapse_of_expand
// CHECK-NOT: tensor.{{.*}}_shape

// -----

func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index)
    -> tensor<?x?xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
    : tensor<?x?xf32> into tensor<?x4x?xf32>
  %1 = tensor.collapse_shape %0 [[0, 1], [2]]
    : tensor<?x4x?xf32> into tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: @fold_collapse_of_expand_dynamic
// CHECK-NOT: tensor.{{.*}}_shape

// -----

func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
    -> tensor<?x?xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
    : tensor<?x?xf32> into tensor<?x?x?xf32>
  %1 = tensor.collapse_shape %0 [[0, 1], [2]]
    : tensor<?x?x?xf32> into tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: @fold_collapse_of_expand_fully_dynamic
// CHECK-NOT: tensor.{{.*}}_shape

// -----

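// Test case: the expand and collapse use different reassociations on dynamic dims,
// so the pair is not folded.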
func.func @no_fold_parallel_collapse_of_expand_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
    -> tensor<?x?x?xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1], [2], [3]] output_shape [%arg1, %arg2, %arg3, %arg4]
    : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
  %1 = tensor.collapse_shape %0 [[0], [1], [2, 3]]
    : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}
// CHECK-LABEL: @no_fold_parallel_collapse_of_expand_dynamic
// CHECK: tensor.expand_shape
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
// CHECK: return %[[COLLAPSE]]

// -----

func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
    : tensor<3x4x4xf32> into tensor<12x4xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 4]
    : tensor<12x4xf32> into tensor<3x4x4xf32>
  return %1 : tensor<3x4x4xf32>
}
// CHECK-LABEL: @fold_expand_of_collapse
// CHECK-NOT: tensor.{{.*}}_shape

// -----

func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
    -> tensor<?x4x?xf32> {
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
    : tensor<?x4x?xf32> into tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
    : tensor<?x?xf32> into tensor<?x4x?xf32>
  return %1 : tensor<?x4x?xf32>
}
// CHECK-LABEL: @fold_expand_of_collapse_dynamic
// CHECK-NOT: tensor.{{.*}}_shape

// -----

func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
    -> tensor<?x?x?xf32> {
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
    : tensor<?x?x?xf32> into tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
    : tensor<?x?xf32> into tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}
// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
// CHECK: tensor.collapse_shape
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
// CHECK: return %[[EXPAND]]

// -----

func.func @compose_expand_of_collapse_last_two_dims(%arg0: tensor<?x64x1xf32>) -> tensor<?x384xf32> {
  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %collapsed, %c0 : tensor<?xf32>
  %c384 = arith.constant 384 : index
  %div = arith.divui %dim, %c384 : index
  %expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [%div, 384] : tensor<?xf32> into tensor<?x384xf32>
  return %expanded : tensor<?x384xf32>
}
// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 64)>
// CHECK-LABEL: @compose_expand_of_collapse_last_two_dims
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x64x1xf32>
// CHECK: %[[CONSTANT0:.+]] = arith.constant 0 : index
// CHECK: %[[CONSTANT384:.+]] = arith.constant 384 : index
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[CONSTANT0]] : tensor<?x64x1xf32>
// CHECK: %[[AFFAPPLY:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
// CHECK: %[[DIVUI:.+]] = arith.divui %[[AFFAPPLY]], %[[CONSTANT384]] : index
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1]] output_shape [%[[DIVUI]], 384] : tensor<?xf32> into tensor<?x384xf32>
// CHECK: return %[[RESULT]]

// -----

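// Test case: a collapse followed by an expand is rewritten into a single collapse
// when the new grouping is compatible.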
func.func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
    -> tensor<24x5x42x8xf32> {
  %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
    : tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
  %1 = tensor.expand_shape %0 [[0, 1, 2, 3]] output_shape [24, 5, 42, 8]
    : tensor<40320xf32> into tensor<24x5x42x8xf32>
  return %1 : tensor<24x5x42x8xf32>
}
// CHECK: func @compose_expand_of_collapse
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32>
// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK-SAME: [0, 1, 2], [3], [4, 5], [6]
// CHECK: return %[[RESULT]]

// -----

func.func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
    -> tensor<2x3x4x5x6x7x8xf32> {
  %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3]]
    : tensor<24x5x42x8xf32> into tensor<40320xf32>
  %1 = tensor.expand_shape %0 [[0, 1, 2, 3, 4, 5, 6]] output_shape [2, 3, 4, 5, 6, 7, 8]
    : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
  return %1 : tensor<2x3x4x5x6x7x8xf32>
}
// CHECK: func @compose_expand_of_collapse_7D
// CHECK-SAME: %[[ARG0:.+]]: tensor<24x5x42x8xf32>
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0, 1, 2], [3], [4, 5], [6]
// CHECK: return %[[RESULT]]

// -----

func.func @compose_collapse_of_expand(%arg : tensor<?x?x?xi64>, %arg1: index, %arg2: index, %arg3: index)
    -> tensor<?x?xi64> {
  %0 = tensor.expand_shape %arg [[0], [1], [2, 3]] output_shape [%arg1, %arg2, %arg3, 1]
    : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
  %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
    : tensor<?x?x?x1xi64> into tensor<?x?xi64>
  return %1 : tensor<?x?xi64>
}
// CHECK-LABEL: func @compose_collapse_of_expand
// CHECK: (%[[ARG:.*]]: tensor<?x?x?xi64>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
// CHECK-NEXT: tensor.collapse_shape %[[ARG]]
// CHECK-SAME: [0, 1], [2]
// CHECK-SAME: : tensor<?x?x?xi64> into tensor<?x?xi64>

// -----

func.func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
    -> tensor<4x512xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]] output_shape [1, 4, 1, 512]
    : tensor<2048xf32> into tensor<1x4x1x512xf32>
  %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
    : tensor<1x4x1x512xf32> into tensor<4x512xf32>
  return %1 : tensor<4x512xf32>
}
// CHECK: func @compose_collapse_of_expand_1D
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [4, 512]
// CHECK-SAME: tensor<2048xf32> into tensor<4x512xf32>

// -----

func.func @compose_expand_of_collapse_0_rank_to_expand(%arg0 : tensor<1x1x1xf32>)
    -> tensor<1x1x1x1xf32> {
  %0 = tensor.collapse_shape %arg0 []
    : tensor<1x1x1xf32> into tensor<f32>
  %1 = tensor.expand_shape %0 [] output_shape [1, 1, 1, 1]
    : tensor<f32> into tensor<1x1x1x1xf32>
  return %1 : tensor<1x1x1x1xf32>
}
// CHECK: func @compose_expand_of_collapse_0_rank_to_expand
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x1xf32>
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: {{\[}}[0], [1], [2, 3]] output_shape [1, 1, 1, 1]
// CHECK: return %[[RESULT]]

// -----

func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1xf32>)
    -> tensor<1x1x1xf32> {
  %0 = tensor.collapse_shape %arg0 []
    : tensor<1x1x1x1xf32> into tensor<f32>
  %1 = tensor.expand_shape %0 [] output_shape [1, 1, 1]
    : tensor<f32> into tensor<1x1x1xf32>
  return %1 : tensor<1x1x1xf32>
}
// CHECK: func @compose_expand_of_collapse_0_rank_to_collapse
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x1x1xf32>
// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
= tensor.collapse_shape %[[ARG0]] 1380// CHECK-SAME: [0], [1], [2, 3] 1381// CHECK: return %[[RESULT]] 1382 1383// ----- 1384 1385func.func @compose_expand_of_collapse_static(%arg0 : tensor<4x32x10x64x2xf16>) -> tensor<4x32x10x128xf16> { 1386 %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x32x10x64x2xf16> into tensor<128x10x128xf16> 1387 %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 32, 10, 128] : tensor<128x10x128xf16> into tensor<4x32x10x128xf16> 1388 return %expanded : tensor<4x32x10x128xf16> 1389} 1390 1391// CHECK-LABEL: func @compose_expand_of_collapse_static 1392// CHECK-SAME: %[[ARG0:.+]]: tensor<4x32x10x64x2xf16> 1393// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]] 1394// CHECK-SAME: [0], [1], [2], [3, 4] 1395// CHECK: return %[[RESULT]] 1396 1397// ----- 1398 1399func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %arg1 : index) -> tensor<4x?x10x128xf16> { 1400 %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x?x10x64x2xf16> into tensor<?x10x128xf16> 1401 %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, %arg1, 10, 128] : tensor<?x10x128xf16> into tensor<4x?x10x128xf16> 1402 return %expanded : tensor<4x?x10x128xf16> 1403} 1404 1405// CHECK-LABEL: func @compose_expand_of_collapse_dynamic 1406// CHECK-SAME: %[[ARG0:.+]]: tensor<4x?x10x64x2xf16> 1407// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]] 1408// CHECK-SAME: [0], [1], [2], [3, 4] 1409// CHECK: return %[[RESULT]] 1410 1411// ----- 1412 1413// CHECK-LABEL: func @zero_rank_reshape_multi 1414func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> { 1415 // CHECK: return %arg0 1416 %0 = tensor.expand_shape %arg0 [] output_shape [1] : tensor<f32> into tensor<1xf32> 1417 %1 = tensor.expand_shape %0 [[0, 1]] output_shape [1, 1] : tensor<1xf32> into tensor<1x1xf32> 1418 %2 = tensor.collapse_shape %1 [] : tensor<1x1xf32> into tensor<f32> 1419 return %2 : tensor<f32> 1420} 1421 1422// ----- 1423 1424func.func @compose_collapse_of_collapse(%arg0 : tensor<?x?x?x?x?xf32>) 1425 -> tensor<?x?xf32> { 1426 %0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] 1427 : tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32> 1428 %1 = tensor.collapse_shape %0 [[0, 1], [2]] 1429 : tensor<?x?x?xf32> into tensor<?x?xf32> 1430 return %1 : tensor<?x?xf32> 1431} 1432// CHECK-LABEL: func @compose_collapse_of_collapse 1433// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] 1434// CHECK-NOT: tensor.collapse_shape 1435 1436// ----- 1437 1438func.func @compose_collapse_of_collapse_zero_dim(%arg0 : tensor<1x1x1xf32>) 1439 -> tensor<f32> { 1440 %0 = tensor.collapse_shape %arg0 [[0, 1, 2]] 1441 : tensor<1x1x1xf32> into tensor<1xf32> 1442 %1 = tensor.collapse_shape %0 [] : tensor<1xf32> into tensor<f32> 1443 return %1 : tensor<f32> 1444} 1445// CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim 1446// CHECK: tensor.collapse_shape %{{.*}} [] 1447// CHECK-SAME: tensor<1x1x1xf32> into tensor<f32> 1448 1449// ----- 1450 1451func.func @fold_collapse_of_expand_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> { 1452 %0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]] output_shape [1, 4, 1, 512] 1453 : tensor<4x512xf32> into tensor<1x4x1x512xf32> 1454 %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]] 1455 : tensor<1x4x1x512xf32> into tensor<2048xf32> 1456 return %1 : tensor<2048xf32> 1457} 1458// CHECK: func @fold_collapse_of_expand_1D 1459// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1]] 
1460// CHECK-SAME: tensor<4x512xf32> into tensor<2048xf32> 1461 1462// ----- 1463 1464func.func @fold_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x1xf32>) 1465 -> tensor<4x512x1x1xf32> { 1466 %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]] output_shape [1, 4, 1, 512, 1, 1] : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32> 1467 %1 = tensor.collapse_shape %0 [[0, 1, 2], [3], [4], [5]] 1468 : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32> 1469 return %1 : tensor<4x512x1x1xf32> 1470} 1471// CHECK: func @fold_collapse_of_expand_unit_dims 1472// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]] output_shape [4, 512, 1, 1] 1473// CHECK-SAME: tensor<2048x1x1xf32> into tensor<4x512x1x1xf32> 1474 1475// ----- 1476 1477func.func @compose_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x2048xf32>) 1478 -> tensor<4x512x1x512x4xf32> { 1479 %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]] output_shape [1, 4, 1, 512, 1, 1, 512, 1, 4] : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32> 1480 %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4], [5], [6, 7], [8]] 1481 : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32> 1482 return %1 : tensor<4x512x1x512x4xf32> 1483} 1484// CHECK: func @compose_collapse_of_expand_unit_dims 1485// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]] output_shape [4, 512, 1, 512, 4] 1486// CHECK-SAME: tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32> 1487 1488// ----- 1489 1490func.func @compose_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>) 1491 -> tensor<2x1xf32> { 1492 %0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 1, 1] 1493 : tensor<2xf32> into tensor<2x1x1xf32> 1494 %1 = tensor.collapse_shape %0 [[0], [1, 2]] 1495 : tensor<2x1x1xf32> into tensor<2x1xf32> 1496 return %1 : tensor<2x1xf32> 1497} 1498// CHECK: func @compose_collapse_of_expand_trailing_unit_dims 1499// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [2, 1] 1500// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32> 1501 1502// ----- 1503 1504func.func @compose_collapse_of_collapse_unit_dims_dynamic( 1505 %arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>) -> tensor<?x?x?x?xf32> { 1506 %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]] 1507 : tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x1x1x?x?xf32> 1508 %1 = tensor.collapse_shape %0 [[0], [1], [2, 3, 4], [5]] 1509 : tensor<?x?x1x1x?x?xf32> into tensor<?x?x?x?xf32> 1510 return %1 : tensor<?x?x?x?xf32> 1511} 1512// CHECK: func @compose_collapse_of_collapse_unit_dims_dynamic 1513// CHECK: tensor.collapse_shape 1514// CHECK-SAME: [0], [1, 2], [3, 4, 5], [6, 7, 8] 1515// CHECK-SAME: tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x?x?xf32> 1516 1517// ----- 1518 1519func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>) 1520 -> tensor<2x1xf32> { 1521 %0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 1, 1] : tensor<2xf32> into tensor<2x1x1xf32> 1522 %1 = tensor.collapse_shape %0 [[0], [1, 2]] 1523 : tensor<2x1x1xf32> into tensor<2x1xf32> 1524 return %1 : tensor<2x1xf32> 1525} 1526// CHECK: func @fold_collapse_of_expand_trailing_unit_dims 1527// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [2, 1] 1528// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32> 1529 1530// ----- 1531 1532func.func @fold_collapse_of_collapse_trailing_unit_dims_dynamic( 1533 %arg0: tensor<1x1x?x1x1x1xf32>) -> tensor<?xf32> { 1534 %0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]] 1535 : tensor<1x1x?x1x1x1xf32> 
into tensor<?x1x1x1xf32> 1536 %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]] 1537 : tensor<?x1x1x1xf32> into tensor<?xf32> 1538 return %1 : tensor<?xf32> 1539} 1540// CHECK: func @fold_collapse_of_collapse_trailing_unit_dims_dynamic 1541// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]] 1542// CHECK-SAME: tensor<1x1x?x1x1x1xf32> into tensor<?xf32> 1543 1544// ----- 1545 1546func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>) 1547 -> tensor<12x42xf32> { 1548 %0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [12, 42, 1, 1, 1] : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32> 1549 %1 = tensor.collapse_shape %0 [[0], [1, 2, 3, 4]] 1550 : tensor<12x42x1x1x1xf32> into tensor<12x42xf32> 1551 return %1 : tensor<12x42xf32> 1552} 1553// CHECK: func @fold_collapse_of_expand_trailing_unit_dims 1554// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]] 1555// CHECK-SAME: tensor<12x42x1x1xf32> into tensor<12x42xf32> 1556 1557// ----- 1558 1559func.func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>, %sz0: index, %sz1: index, %sz2: index) 1560 -> tensor<?x?xf32> { 1561 %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [%sz0, %sz1, 1, %sz2] 1562 : tensor<?x?x?xf32> into tensor<?x?x1x?xf32> 1563 %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] 1564 : tensor<?x?x1x?xf32> into tensor<?x?xf32> 1565 return %1 : tensor<?x?xf32> 1566} 1567// CHECK-LABEL: func @fold_collapse_of_expand_unit_dims_in_middle 1568// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32> 1569// CHECK: tensor.collapse_shape %[[ARG]] {{\[}}[0], [1, 2]] 1570// CHECK-SAME: tensor<?x?x?xf32> into tensor<?x?xf32> 1571 1572// ----- 1573 1574func.func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>) 1575 -> tensor<2x6x16xf32> { 1576 %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]] output_shape [2, 2, 3, 2, 8] 1577 : tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32> 1578 %1 = tensor.collapse_shape %0 [[0], [1, 2], [3, 4]] 1579 : tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32> 1580 return %1 : tensor<2x6x16xf32> 1581} 1582// CHECK-LABEL: func @no_fold_collapse_of_expand_incompatible 1583// CHECK: tensor.expand_shape 1584// CHECK: tensor.collapse_shape 1585 1586// ----- 1587 1588func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>) 1589 -> tensor<12x1xf32> { 1590 %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [3, 2, 2, 1] 1591 : tensor<3x2x2xf32> into tensor<3x2x2x1xf32> 1592 %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] 1593 : tensor<3x2x2x1xf32> into tensor<12x1xf32> 1594 return %1 : tensor<12x1xf32> 1595} 1596// CHECK: func @no_fold_collapse_of_expand_empty_expr 1597// CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32> 1598// CHECK: %[[RARG0:.+]] = tensor.expand_shape %[[ARG0]] 1599// CHECK-SAME: {{\[}}[0], [1], [2, 3]] output_shape [3, 2, 2, 1] 1600// CHECK: %[[RES:.+]] = tensor.collapse_shape %[[RARG0]] 1601// CHECK-SAME: [0, 1, 2], [3] 1602// CHECK: return %[[RES:.+]] : tensor<12x1xf32> 1603 1604// ----- 1605 1606func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> { 1607 %c0 = arith.constant dense<42> : tensor<2x8xi32> 1608 %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2] 1609 : tensor<2x8xi32> into tensor<2x4x2xi32> 1610 return %0 : tensor<2x4x2xi32> 1611} 1612// CHECK-LABEL: @reshape_splat_constant_int32 1613// CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xi32> 1614// CHECK-NOT: tensor.expand_shape 1615// CHECK: return %[[CST]] 1616// ----- 
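// The next few splat tests exercise the reshape-of-splat folders: expanding or
// collapsing a tensor.splat with only static sizes is expected to become a splat of
// the reshaped type, while the *_dynamic_no_fold variants below keep the reshape
// because the splat carries a dynamic size.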
1617func.func @expand_shape_splat(%arg : f32) -> tensor<2x2x2xf32> { 1618 %c0 = tensor.splat %arg : tensor<2x4xf32> 1619 %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 2, 2] 1620 : tensor<2x4xf32> into tensor<2x2x2xf32> 1621 return %0 : tensor<2x2x2xf32> 1622} 1623// CHECK-LABEL: @expand_shape_splat 1624// CHECK-SAME: %[[ARG0:.+]]: f32 1625// CHECK: %[[CST:.*]] = tensor.splat %[[ARG0:.+]] : tensor<2x2x2xf32> 1626// CHECK-NOT: tensor.expand_shape 1627// CHECK: return %[[CST]] 1628 1629// ----- 1630 1631// CHECK-LABEL: @expand_shape_splat_dynamic_no_fold 1632// CHECK-SAME: (%[[F:.+]]: f32, %[[M:.+]]: index, %[[SZ0:.+]]: index) 1633func.func @expand_shape_splat_dynamic_no_fold(%arg: f32, %m: index, %sz0: index) -> tensor<2x2x?xf32> { 1634 // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]] : tensor<2x?xf32> 1635 // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[SPLAT]] 1636 %c0 = tensor.splat %arg[%m] : tensor<2x?xf32> 1637 %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 2, %sz0] : tensor<2x?xf32> into tensor<2x2x?xf32> 1638 return %0 : tensor<2x2x?xf32> 1639} 1640 1641// ----- 1642 1643func.func @collapse_shape_splat(%arg : f32) -> tensor<2x4xf32> { 1644 %c0 = tensor.splat %arg : tensor<2x2x2xf32> 1645 %0 = tensor.collapse_shape %c0 [[0], [1, 2]] 1646 : tensor<2x2x2xf32> into tensor<2x4xf32> 1647 return %0 : tensor<2x4xf32> 1648} 1649// CHECK-LABEL: @collapse_shape_splat 1650// CHECK-SAME: %[[ARG0:.+]]: f32 1651// CHECK: %[[CST:.*]] = tensor.splat %[[ARG0:.+]] : tensor<2x4xf32> 1652// CHECK-NOT: tensor.collapse_shape 1653// CHECK: return %[[CST]] 1654 1655// ----- 1656 1657// CHECK-LABEL: @collapse_shape_splat_dynamic_no_fold 1658// CHECK-SAME: %[[F:.+]]: f32 1659// CHECK-SAME: %[[M:.+]]: index 1660func.func @collapse_shape_splat_dynamic_no_fold(%f: f32, %m: index) -> tensor<2x?xf32> { 1661 // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]] 1662 // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[SPLAT]] 1663 %c0 = tensor.splat %f[%m] : tensor<2x2x?xf32> 1664 %0 = tensor.collapse_shape %c0 [[0], [1, 2]] : tensor<2x2x?xf32> into tensor<2x?xf32> 1665 return %0 : tensor<2x?xf32> 1666} 1667 1668// ----- 1669 1670func.func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> { 1671 %c0 = arith.constant dense<42> : tensor<2x8xi16> 1672 %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2] 1673 : tensor<2x8xi16> into tensor<2x4x2xi16> 1674 return %0 : tensor<2x4x2xi16> 1675} 1676// CHECK-LABEL: @reshape_splat_constant_int16 1677// CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xi16> 1678// CHECK-NOT: tensor.expand_shape 1679// CHECK: return %[[CST]] 1680 1681// ----- 1682 1683func.func @reshape_splat_constant_float32() -> tensor<2x4x2xf32> { 1684 %c0 = arith.constant dense<42.0> : tensor<2x8xf32> 1685 %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2] 1686 : tensor<2x8xf32> into tensor<2x4x2xf32> 1687 return %0 : tensor<2x4x2xf32> 1688} 1689// CHECK-LABEL: @reshape_splat_constant_float32 1690// CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xf32> 1691// CHECK-NOT: tensor.expand_shape 1692// CHECK: return %[[CST]] 1693 1694// ----- 1695 1696func.func @reshape_splat_constant_float64() -> tensor<2x4x2xf64> { 1697 %c0 = arith.constant dense<42.0> : tensor<2x8xf64> 1698 %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2] 1699 : tensor<2x8xf64> into tensor<2x4x2xf64> 1700 return %0 : tensor<2x4x2xf64> 1701} 1702// CHECK-LABEL: @reshape_splat_constant_float64 1703// CHECK: %[[CST:.*]] = 
arith.constant dense<{{.*}}> : tensor<2x4x2xf64>
// CHECK-NOT: tensor.expand_shape
// CHECK: return %[[CST]]

// -----

// CHECK-LABEL: func @fold_rank
func.func @fold_rank() -> (index) {
  %const_0 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]>
    : tensor<2x1x4xi32>

  // Fold a rank into a constant
  // CHECK-NEXT: [[C3:%.+]] = arith.constant 3 : index
  %rank_0 = tensor.rank %const_0 : tensor<2x1x4xi32>

  // CHECK-NEXT: return [[C3]]
  return %rank_0 : index
}

// -----

// CHECK-LABEL: func @pad_same_static_shape(
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
// CHECK-NOT: tensor.pad
// CHECK: return %[[ARG0]]
func.func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
    -> tensor<5x6xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.pad %arg0 low[%a, 0] high[0, %a] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %cst : f32
  } : tensor<5x6xf32> to tensor<5x6xf32>
  return %0 : tensor<5x6xf32>
}

// -----

// CHECK-LABEL: func @pad_fold_static(
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NOT: arith.constant 4 : index
// CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]]
// CHECK-SAME: low[0, 4, 1, 1] high[0, 4, 1, 1] {
// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
// CHECK: tensor.yield %[[CST]] : f32
// CHECK: } : tensor<?x64x?x?xf32> to tensor<?x72x?x?xf32>
// CHECK: tensor.cast
func.func @pad_fold_static(%arg0: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %padding = arith.constant 4 : index
  %padded = tensor.pad %arg0 low[0, %padding, 1, 1] high[0, %padding, 1, 1] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
      tensor.yield %cst: f32
  } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
  return %padded : tensor<?x?x?x?xf32>
}

// -----

// CHECK-LABEL: func @pad_nofold_same_static_shape(
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
// CHECK: %[[PAD:.*]] = tensor.pad
// CHECK: return %[[PAD]]
func.func @pad_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
    -> tensor<5x6xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.pad %arg0 nofold low[%a, 0] high[0, %a] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %cst : f32
  } : tensor<5x6xf32> to tensor<5x6xf32>
  return %0 : tensor<5x6xf32>
}

// -----

// CHECK-LABEL: func @pad_after_cast_different_shape(
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]]
// CHECK-SAME: low[0, 0, 1, 1] high[0, 0, 1, 1] {
// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
// CHECK: tensor.yield %[[CST]] : f32
// CHECK: } : tensor<?x64x?x?xf32> to tensor<?x64x?x?xf32>
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[PADDED:.*]] :
// CHECK-SAME: tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
// CHECK: return %[[DYNAMIC]] : tensor<?x?x?x?xf32>
// CHECK: }
func.func @pad_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
    -> tensor<?x?x?x?xf32> {
%cst = arith.constant 0.000000e+00 : f32 1794 %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32> 1795 %padded = tensor.pad %dynamic low[0, 0, 1, 1] high[0, 0, 1, 1] { 1796 ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index): 1797 tensor.yield %cst: f32 1798 } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32> 1799 return %padded: tensor<?x?x?x?xf32> 1800} 1801 1802// ----- 1803 1804// CHECK-LABEL: func @pad_after_cast_same_shape( 1805// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>, 1806// CHECK-SAME: %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> { 1807// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 1808// CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]] 1809// CHECK-SAME: low[0, %[[PADDING]], 1, 1] high[0, %[[PADDING]], 1, 1] { 1810// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index): 1811// CHECK: tensor.yield %[[CST]] : f32 1812// CHECK: } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32> 1813// CHECK: return %[[PADDED:.*]] : tensor<?x?x?x?xf32> 1814// CHECK: } 1815func.func @pad_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index) 1816 -> tensor<?x?x?x?xf32> { 1817 %cst = arith.constant 0.000000e+00 : f32 1818 %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32> 1819 %padded = tensor.pad %dynamic low[0, %padding, 1, 1] high[0, %padding, 1, 1] { 1820 ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index): 1821 tensor.yield %cst: f32 1822 } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32> 1823 return %padded: tensor<?x?x?x?xf32> 1824} 1825 1826// ----- 1827 1828// CHECK-LABEL: func @pad_of_cast( 1829// CHECK-NOT: tensor.cast 1830// CHECK: tensor.pad 1831// CHECK: tensor<8x?xf32> to tensor<8x32xf32> 1832func.func @pad_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> { 1833 %c0 = arith.constant 0 : index 1834 %cst = arith.constant 0.000000e+00 : f32 1835 %0 = tensor.cast %t : tensor<8x?xf32> to tensor<?x?xf32> 1836 %1 = tensor.pad %0 low[%c0, %c0] high[%c0, %s] { 1837 ^bb0(%arg9: index, %arg10: index): 1838 tensor.yield %cst : f32 1839 } : tensor<?x?xf32> to tensor<8x32xf32> 1840 return %1 : tensor<8x32xf32> 1841} 1842 1843// ----- 1844 1845// CHECK-LABEL: @cast_of_pad_more_static 1846func.func @cast_of_pad_more_static(%arg0: tensor<?x?xf32>, %padding: index) -> tensor<32x32xf32> { 1847 %cst = arith.constant 0.000000e+00 : f32 1848 // CHECK: %[[PAD:.*]] = tensor.pad 1849 // CHECK: tensor<?x?xf32> to tensor<32x32xf32> 1850 %padded = tensor.pad %arg0 low[%padding, %padding] high[0, 0] { 1851 ^bb0(%arg1: index, %arg2: index): 1852 tensor.yield %cst : f32 1853 } : tensor<?x?xf32> to tensor<?x?xf32> 1854 // CHECK-NOT: tensor.cast 1855 %casted = tensor.cast %padded : tensor<?x?xf32> to tensor<32x32xf32> 1856 // CHECK: return %[[PAD]] 1857 return %casted : tensor<32x32xf32> 1858} 1859 1860// ----- 1861 1862// CHECK-LABEL: @cast_of_pad_less_static 1863func.func @cast_of_pad_less_static(%arg0: tensor<32x?x?xf32>, %padding: index) -> tensor<?x32x32xf32> { 1864 %cst = arith.constant 0.000000e+00 : f32 1865 // CHECK: tensor.pad 1866 %padded = tensor.pad %arg0 low[%padding, %padding, %padding] high[0, 0, 0] { 1867 ^bb0(%arg1: index, %arg2: index, %arg3: index): 1868 tensor.yield %cst : f32 1869 } : tensor<32x?x?xf32> to tensor<32x?x?xf32> 1870 // CHECK: %[[CAST:.*]] = tensor.cast 1871 %casted = tensor.cast %padded : tensor<32x?x?xf32> to tensor<?x32x32xf32> 1872 // CHECK: return %[[CAST]] 1873 return %casted : tensor<?x32x32xf32> 1874} 1875 1876// ----- 1877 1878func.func 
@pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { 1879 %c0 = arith.constant 0 : index 1880 %cst = arith.constant 0.0 : f32 1881 %0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32> 1882 %1 = tensor.pad %0 low[%c0, %c0] high[%c0, %c0] { 1883 ^bb0(%arg1: index, %arg2: index): 1884 tensor.yield %cst : f32 1885 } : tensor<?x?xf32> to tensor<4x4xf32> 1886 return %1 : tensor<4x4xf32> 1887} 1888// CHECK-LABEL: @pad_cast 1889// CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32> 1890// CHECK: return %[[ARG0]] 1891 1892// ----- 1893 1894// CHECK-LABEL: func @fold_pad_source_cast( 1895// CHECK-SAME: %[[ARG0:.*]]: tensor<4x?xf32> 1896// CHECK-NOT: tensor.cast 1897// CHECK: %[[RESULT:.*]] = tensor.pad %[[ARG0]] 1898func.func @fold_pad_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> { 1899 %cst = arith.constant 0.0 : f32 1900 %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32> 1901 %1 = tensor.pad %0 low[0, 0] high[0, 1] { 1902 ^bb0(%arg1: index, %arg2: index): 1903 tensor.yield %cst : f32 1904 } : tensor<?x?xf32> to tensor<4x4xf32> 1905 return %1 : tensor<4x4xf32> 1906} 1907 1908// ----- 1909 1910// CHECK-LABEL: func @pad_static_zero_cast( 1911// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?xf32> 1912// CHECK-NOT: tensor.pad 1913// CHECK: %[[RESULT:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32> 1914// CHECK: return %[[RESULT]] 1915func.func @pad_static_zero_cast(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> { 1916 %c0 = arith.constant 0 : index 1917 %0 = tensor.pad %arg0 low[0, %c0, 0] high[0, 0, %c0] { 1918 ^bb0(%arg1: index, %arg2: index, %arg3: index): 1919 tensor.yield %pad_value : f32 1920 } : tensor<?x?x?xf32> to tensor<2x3x4xf32> 1921 1922 return %0 : tensor<2x3x4xf32> 1923} 1924 1925// ----- 1926 1927// CHECK-LABEL: func @pad_nofold_static_zero( 1928// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?xf32> 1929// CHECK: %[[PAD:.*]] = tensor.pad 1930// CHECK: return %[[PAD]] 1931func.func @pad_nofold_static_zero(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> { 1932 %c0 = arith.constant 0 : index 1933 %0 = tensor.pad %arg0 nofold low[0, %c0, 0] high[0, 0, %c0] { 1934 ^bb0(%arg1: index, %arg2: index, %arg3: index): 1935 tensor.yield %pad_value : f32 1936 } : tensor<?x?x?xf32> to tensor<2x3x4xf32> 1937 1938 return %0 : tensor<2x3x4xf32> 1939} 1940 1941// ----- 1942 1943// CHECK-LABEL: func @fold_orthogonal_pad_chains( 1944// CHECK-SAME: %[[ARG0:.*]]: tensor<64x64xf32>, 1945// CHECK-SAME: %[[SZ0:.*]]: index, %[[SZ1:.*]]: index, %[[PW0:.*]]: index, %[[PW1:.*]]: index 1946func.func @fold_orthogonal_pad_chains(%arg0: tensor<64x64xf32>, 1947 %sz0 : index, %sz1 : index, 1948 %pw0 : index, %pw1 : index) -> tensor<8x4xf32> { 1949 // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]] 1950 // CHECK-SAME: [16, 4] [%[[SZ0]], %[[SZ1]]] 1951 // CHECK: %[[PAD:.*]] = tensor.pad %[[T0]] nofold 1952 // CHECK-SAME: high[%[[PW0]], %[[PW1]]] 1953 // CHECK: return %[[PAD]] 1954 %pad_value = arith.constant 0.0 : f32 1955 %0 = tensor.extract_slice %arg0[16, 0] [%sz0, 64] [1, 1] : tensor<64x64xf32> to tensor<?x64xf32> 1956 %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { 1957 ^bb0(%arg1: index, %arg2: index): 1958 tensor.yield %pad_value : f32 1959 } : tensor<?x64xf32> to tensor<8x64xf32> 1960 %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32> 1961 %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { 1962 ^bb0(%arg1: index, %arg2: index): 1963 tensor.yield %pad_value : f32 1964 } : tensor<8x?xf32> to tensor<8x4xf32> 1965 
func.return %3 : tensor<8x4xf32> 1966} 1967 1968// ----- 1969 1970// CHECK-LABEL: func @dont_fold_pad_chains( 1971// CHECK-SAME: %[[ARG0:.*]]: tensor<64x64xf32>, 1972// CHECK-SAME: %[[SZ0:.*]]: index, %[[SZ1:.*]]: index, %[[PW0:.*]]: index, %[[PW1:.*]]: index 1973func.func @dont_fold_pad_chains(%arg0: tensor<64x64xf32>, 1974 %sz0 : index, %sz1 : index, 1975 %pw0 : index, %pw1 : index) -> (tensor<8x4xf32>, tensor<4x64xf32>, tensor<8x4xf32>, tensor<6x4xf32>) { 1976 // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]] 1977 // CHECK: %[[T1:.*]] = tensor.pad %[[T0]] 1978 %pad_value = arith.constant 0.0 : f32 1979 %0 = tensor.extract_slice %arg0[16, 0] [%sz0, 64] [1, 1] : tensor<64x64xf32> to tensor<?x64xf32> 1980 %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { 1981 ^bb0(%arg1: index, %arg2: index): 1982 tensor.yield %pad_value : f32 1983 } : tensor<?x64xf32> to tensor<8x64xf32> 1984 1985 // Don't fold if the padding values are different. 1986 // CHECK: %[[T2:.*]] = tensor.extract_slice %[[T1]] 1987 // CHECK-SAME: [0, 4] [8, %[[SZ1]]] 1988 // CHECK: %[[PAD0:.*]] = tensor.pad %[[T2]] 1989 %different_value = arith.constant 1.0 : f32 1990 %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32> 1991 %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { 1992 ^bb0(%arg1: index, %arg2: index): 1993 tensor.yield %different_value : f32 1994 } : tensor<8x?xf32> to tensor<8x4xf32> 1995 1996 // Don't fold if the pad ops have common padding dimensions. 1997 // CHECK: %[[T3:.*]] = tensor.extract_slice %[[T1]] 1998 // CHECK-SAME: [4, 0] [%[[SZ1]], 64] 1999 // CHECK: %[[PAD1:.*]] = tensor.pad %[[T3]] 2000 %4 = tensor.extract_slice %1[4, 0] [%sz1, 64] [1, 1] : tensor<8x64xf32> to tensor<?x64xf32> 2001 %5 = tensor.pad %4 nofold low[0, 0] high[%pw1, 0] { 2002 ^bb0(%arg1: index, %arg2: index): 2003 tensor.yield %pad_value : f32 2004 } : tensor<?x64xf32> to tensor<4x64xf32> 2005 2006 // Don't fold if padded source tensor dimension is accessed at an offset. 2007 // CHECK: %[[T4:.*]] = tensor.extract_slice %[[T1]] 2008 // CHECK-SAME: [%[[SZ0]], 4] [8, %[[SZ1]] 2009 // CHECK: %[[PAD2:.*]] = tensor.pad %[[T4]] 2010 %6 = tensor.extract_slice %1[%sz0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32> 2011 %7 = tensor.pad %6 nofold low[0, 0] high[0, %pw1] { 2012 ^bb0(%arg1: index, %arg2: index): 2013 tensor.yield %pad_value : f32 2014 } : tensor<8x?xf32> to tensor<8x4xf32> 2015 2016 // Don't fold if a padded source tensor dimension is sliced. 
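  // (Here the slice keeps only 6 of the 8 rows produced by the first pad, so
  // composing the two pads would change the result.)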
  // CHECK: %[[T5:.*]] = tensor.extract_slice %[[T1]]
  // CHECK-SAME: [0, 4] [6, %[[SZ1]]
  // CHECK: %[[PAD3:.*]] = tensor.pad %[[T5]]
  %8 = tensor.extract_slice %1[0, 4] [6, %sz1] [1, 1] : tensor<8x64xf32> to tensor<6x?xf32>
  %9 = tensor.pad %8 nofold low[0, 0] high[0, %pw1] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %pad_value : f32
  } : tensor<6x?xf32> to tensor<6x4xf32>

  // CHECK: return %[[PAD0]], %[[PAD1]], %[[PAD2]], %[[PAD3]]
  func.return %3, %5, %7, %9 : tensor<8x4xf32>, tensor<4x64xf32>, tensor<8x4xf32>, tensor<6x4xf32>
}

// -----

// CHECK-LABEL: func @merge_constant_padding
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<2x3xf32>
// CHECK-SAME: %[[PADVAL:[A-Za-z0-9]+]]: f32
// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[1, 3] high[4, 2]
// CHECK: tensor.yield %[[PADVAL]]
// CHECK: return %[[PAD]]
func.func @merge_constant_padding(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
  %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
    ^bb0(%b0: index, %b1 : index):
      tensor.yield %pad_value : f32
  } : tensor<2x3xf32> to tensor<4x4xf32>
  %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
    ^bb0(%b2: index, %b3 : index):
      tensor.yield %pad_value : f32
  } : tensor<4x4xf32> to tensor<7x8xf32>
  return %pad1 : tensor<7x8xf32>
}

// -----

// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 + 1)>
// CHECK-LABEL: func @merge_constant_padding_dynamic
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[IDX:[A-Za-z0-9]+]]: index
// CHECK-SAME: %[[PADVAL:[A-Za-z0-9]+]]: f32
// CHECK: %[[HIGH:.+]] = affine.apply #[[$MAP]]()[%[[IDX]]]
// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[%[[IDX]], 3] high[%[[HIGH]], 2]
// CHECK: tensor.yield %[[PADVAL]]
// CHECK: return %[[PAD]]
func.func @merge_constant_padding_dynamic(%arg0: tensor<?x?xf32>, %idx: index, %pad_value: f32) -> tensor<?x?xf32> {
  %pad0 = tensor.pad %arg0 low[%idx, 1] high[1, 0] {
    ^bb0(%b0: index, %b1 : index):
      tensor.yield %pad_value : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>
  %pad1 = tensor.pad %pad0 low[0, 2] high[%idx, 2] {
    ^bb0(%b2: index, %b3 : index):
      tensor.yield %pad_value : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>
  return %pad1 : tensor<?x?xf32>
}

// -----

// Verify that folding does not happen if it would drop a nofold attribute
// CHECK-LABEL: func @dont_merge_constant_padding_nofold
// CHECK: tensor.pad {{.*}} nofold
// CHECK: tensor.pad
func.func @dont_merge_constant_padding_nofold(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
  %pad0 = tensor.pad %arg0 nofold low[1, 1] high[1, 0] {
    ^bb0(%b0: index, %b1 : index):
      tensor.yield %pad_value : f32
  } : tensor<2x3xf32> to tensor<4x4xf32>
  %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
    ^bb0(%b2: index, %b3 : index):
      tensor.yield %pad_value : f32
  } : tensor<4x4xf32> to tensor<7x8xf32>
  return %pad1 : tensor<7x8xf32>
}

// -----

// Verify that folding does not happen if the two pads use different padding values
// CHECK-LABEL: func @dont_merge_constant_padding_different_vals
// CHECK: tensor.pad
// CHECK: tensor.pad
func.func @dont_merge_constant_padding_different_vals(
    %arg0: tensor<2x3xf32>,
    %pad_value0: f32,
    %pad_value1: f32) -> tensor<7x8xf32> {
  %pad0 = tensor.pad %arg0 low[1, 1]
high[1, 0] { 2102 ^bb0(%b0: index, %b1 : index): 2103 tensor.yield %pad_value0 : f32 2104 } : tensor<2x3xf32> to tensor<4x4xf32> 2105 %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] { 2106 ^bb0(%b2: index, %b3 : index): 2107 tensor.yield %pad_value1 : f32 2108 } : tensor<4x4xf32> to tensor<7x8xf32> 2109 return %pad1 : tensor<7x8xf32> 2110} 2111 2112// ----- 2113 2114// CHECK-LABEL: func @fold_collapse_shape_from_elements 2115func.func @fold_collapse_shape_from_elements(%arg0: i32) -> tensor<i32> { 2116 // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<i32> 2117 // CHECK: return %[[FROM]] : tensor<i32> 2118 %0 = tensor.from_elements %arg0 : tensor<1xi32> 2119 %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32> 2120 return %1 : tensor<i32> 2121} 2122 2123// ----- 2124 2125// CHECK-LABEL: func @fold_expand_shape_from_elements 2126func.func @fold_expand_shape_from_elements(%arg0: i32) -> tensor<1xi32> { 2127 // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<1xi32> 2128 // CHECK: return %[[FROM]] : tensor<1xi32> 2129 %0 = tensor.from_elements %arg0 : tensor<i32> 2130 %1 = tensor.expand_shape %0 [] output_shape [1] : tensor<i32> into tensor<1xi32> 2131 return %1 : tensor<1xi32> 2132} 2133 2134// ----- 2135 2136// CHECK-LABEL: func @propagate_index_cast 2137func.func @propagate_index_cast(%arg0: tensor<1xi32>) -> index { 2138 // CHECK: %[[IDX:.+]] = arith.constant 0 2139 // CHECK: %[[EXT:.+]] = tensor.extract %arg0[%[[IDX]]] : tensor<1xi32> 2140 // CHECK: %[[CAST:.+]] = arith.index_cast %[[EXT]] 2141 // CHECK: return %[[CAST]] : index 2142 %c0 = arith.constant 0 : index 2143 %0 = arith.index_cast %arg0 : tensor<1xi32> to tensor<1xindex> 2144 %1 = tensor.extract %0[%c0] : tensor<1xindex> 2145 return %1 : index 2146} 2147 2148// ----- 2149 2150// CHECK-LABEL: func @splat_fold 2151func.func @splat_fold() -> tensor<4xf32> { 2152 %c = arith.constant 1.0 : f32 2153 %t = tensor.splat %c : tensor<4xf32> 2154 return %t : tensor<4xf32> 2155 2156 // CHECK-NEXT: [[T:%.*]] = arith.constant dense<1.000000e+00> : tensor<4xf32> 2157 // CHECK-NEXT: return [[T]] : tensor<4xf32> 2158} 2159 2160// ----- 2161 2162// CHECK-LABEL: func @splat_dynamic_no_fold 2163// CHECK-SAME: %[[M:.+]]: index 2164func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> { 2165 // CHECK: %[[F:.+]] = arith.constant 2166 %f = arith.constant 1.0 : f32 2167 2168 // CHECK: tensor.splat %[[F]][%[[M]]] : tensor<4x?xf32> 2169 %t = tensor.splat %f[%m] : tensor<4x?xf32> 2170 return %t : tensor<4x?xf32> 2171} 2172 2173// ----- 2174 2175// CHECK-LABEL: func @cast_extract_slice 2176func.func @cast_extract_slice(%arg0 : tensor<128x512xf32>, %s : index, %o : index) 2177 -> tensor<16x512xf32> { 2178// CHECK: %[[E:.*]] = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [16, 512] [1, 1] : tensor<128x512xf32> to tensor<16x512xf32> 2179 %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] : tensor<128x512xf32> to tensor<?x512xf32> 2180 %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32> 2181// CHECK: return %[[E]] : tensor<16x512xf32> 2182 return %1 : tensor<16x512xf32> 2183} 2184 2185// ----- 2186 2187// CHECK-LABEL: func @cast_extract_slice_rank_reduce 2188func.func @cast_extract_slice_rank_reduce(%arg0 : tensor<128x512xf32>, %s : index, %o : index) 2189 -> tensor<16xf32> { 2190// CHECK: %[[E:.*]] = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [16, 1] [1, 1] : tensor<128x512xf32> to tensor<16xf32> 2191 %0 = tensor.extract_slice %arg0[%o, 0] [%s, 1] [1, 1] : tensor<128x512xf32> to tensor<?xf32> 2192 %1 = 
tensor.cast %0 : tensor<?xf32> to tensor<16xf32> 2193// CHECK: return %[[E]] : tensor<16xf32> 2194 return %1 : tensor<16xf32> 2195} 2196 2197// ----- 2198 2199// CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices( 2200// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>, 2201// CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>, 2202// CHECK-SAME: %[[num_threads:[0-9a-z]*]]: index 2203func.func @canonicalize_parallel_insert_slice_indices( 2204 %arg0 : tensor<1x5xf32>, %arg1: tensor<?x?xf32>, 2205 %num_threads : index) -> tensor<?x?xf32> 2206{ 2207 %cst = arith.constant 4.200000e+01 : f32 2208 %c0 = arith.constant 0 : index 2209 %c1 = arith.constant 1 : index 2210 2211 // CHECK-NOT: tensor.cast 2212 // CHECK: scf.forall (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) shared_outs(%[[o:.*]] = %[[arg1]]) -> (tensor<?x?xf32>) { 2213 // CHECK-NEXT: scf.forall.in_parallel { 2214 // CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[o]][%[[tidx]], 0] [1, 5] [1, 1] 2215 %2 = scf.forall (%tidx) in (%num_threads) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) { 2216 %3 = tensor.cast %arg0 : tensor<1x5xf32> to tensor<?x5xf32> 2217 scf.forall.in_parallel { 2218 tensor.parallel_insert_slice %3 into %o[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x5xf32> into tensor<?x?xf32> 2219 } 2220 } 2221 return %2 : tensor<?x?xf32> 2222} 2223 2224// ----- 2225 2226// CHECK-LABEL: func.func @fold_insert_slice_after_extract_slice 2227// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x2x2x4xf32>) 2228func.func @fold_insert_slice_after_extract_slice(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> { 2229 %c0 = arith.constant 0 : index 2230 %0 = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32> 2231 %1 = tensor.insert_slice %0 into %input[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32> 2232 // CHECK: return %[[INPUT]] 2233 return %1: tensor<1x2x2x4xf32> 2234} 2235 2236// ----- 2237 2238// CHECK-LABEL: func.func @dont_fold_mismatched_source_dst 2239func.func @dont_fold_mismatched_source_dst(%input0: tensor<1x2x2x4xf32>, %input1: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> { 2240 %c0 = arith.constant 0 : index 2241 // CHECK: tensor.extract_slice 2242 %0 = tensor.extract_slice %input0[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32> 2243 // CHECK: tensor.insert_slice 2244 %1 = tensor.insert_slice %0 into %input1[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32> 2245 return %1: tensor<1x2x2x4xf32> 2246} 2247 2248// ----- 2249 2250// CHECK-LABEL: func.func @dont_fold_mismatched_parameters 2251func.func @dont_fold_mismatched_parameters(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> { 2252 %c0 = arith.constant 0 : index 2253 // CHECK: tensor.extract_slice 2254 %0 = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32> 2255 // CHECK: tensor.insert_slice 2256 %1 = tensor.insert_slice %0 into %input[%c0, 1, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32> 2257 return %1: tensor<1x2x2x4xf32> 2258} 2259 2260// ----- 2261 2262func.func @empty_canonicalize() -> (tensor<4x5x?xf32>) { 2263 %c6 = arith.constant 6 : index 2264 %0 = tensor.empty(%c6) : tensor<4x5x?xf32> 2265 return %0 : tensor<4x5x?xf32> 2266} 2267// CHECK: func @empty_canonicalize 2268// CHECK: %[[T0:.+]] = tensor.empty() : tensor<4x5x6xf32> 2269// CHECK: %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to 
tensor<4x5x?xf32> 2270// CHECK: return %[[T1]] 2271 2272// ----- 2273 2274func.func @fold_empty_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> { 2275 %0 = tensor.empty(%arg0) : tensor<?x12xf32> 2276 %1 = tensor.cast %0 : tensor<?x12xf32> to tensor<1x12xf32> 2277 return %1 : tensor<1x12xf32> 2278} 2279// CHECK: func @fold_empty_tensor_with_cast(%[[ARG0:.+]]: index) 2280// CHECK: %[[T0:.+]] = tensor.empty() : tensor<1x12xf32> 2281// CHECK: return %[[T0]] : tensor<1x12xf32> 2282 2283// ----- 2284 2285func.func private @some_use(%i : index, %j : index) 2286 2287// CHECK-LABEL: func @empty_tensor_canonicalize 2288// CHECK-SAME: %[[I:.*]]: index 2289func.func @empty_tensor_canonicalize(%i : index) { 2290 %c0 = arith.constant 0 : index 2291 %c1 = arith.constant 1 : index 2292 2293 // CHECK-NOT: tensor.empty 2294 %0 = tensor.empty(%i) : tensor<?x42xf32> 2295 2296 // CHECK-NOT: tensor.dim 2297 %1 = tensor.dim %0, %c0: tensor<?x42xf32> 2298 %2 = tensor.dim %0, %c1: tensor<?x42xf32> 2299 2300 // CHECK: %[[c42:.*]] = arith.constant 42 : index 2301 // CHECK: call @some_use(%[[I]], %[[c42]]) 2302 call @some_use(%1, %2) : (index, index) -> () 2303 2304 return 2305} 2306 2307// ----- 2308 2309// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)> 2310// CHECK-LABEL: func @dim_of_expand_shape( 2311// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32> 2312// CHECK: %[[c1:.*]] = arith.constant 1 : index 2313// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32> 2314// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]] 2315// CHECK: return %[[apply]] 2316func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index { 2317 %c2 = arith.constant 2 : index 2318 %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8] 2319 : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32> 2320 %1 = tensor.dim %0, %c2 : tensor<?x1x?x5x1x8xf32> 2321 return %1 : index 2322} 2323 2324// ----- 2325 2326// CHECK: #[[$map:.*]] = affine_map<()[s0, s1, s2] -> (((s0 * s1) * s2) * 7)> 2327// CHECK-LABEL: func @dim_of_collapse_shape( 2328// CHECK-SAME: %[[t:.*]]: tensor<?x?x?x7x?xf32> 2329// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index 2330// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index 2331// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index 2332// CHECK-DAG: %[[dim1:.*]] = tensor.dim %[[t]], %[[c1]] 2333// CHECK-DAG: %[[dim2:.*]] = tensor.dim %[[t]], %[[c2]] 2334// CHECK-DAG: %[[dim4:.*]] = tensor.dim %[[t]], %[[c4]] 2335// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim1]], %[[dim2]], %[[dim4]]] 2336// CHECK: return %[[apply]] 2337func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index { 2338 %c1 = arith.constant 1 : index 2339 %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]] 2340 : tensor<?x?x?x7x?xf32> into tensor<?x?xf32> 2341 %1 = tensor.dim %0, %c1 : tensor<?x?xf32> 2342 return %1 : index 2343} 2344 2345// ----- 2346 2347// Can't fold when dim is out of bound. 
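// The constant index 5 is past the rank-2 result of the collapse_shape, so the
// dim-of-collapse_shape rewrite used above does not apply and the tensor.dim is kept.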
2348// CHECK-LABEL: func @out_of_bound_dim_of_collapse_shape( 2349// CHECK: %[[DIM:.*]] = tensor.dim 2350// CHECK: return %[[DIM]] 2351func.func @out_of_bound_dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index { 2352 %c5 = arith.constant 5 : index 2353 %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]] 2354 : tensor<?x?x?x7x?xf32> into tensor<?x?xf32> 2355 %1 = tensor.dim %0, %c5 : tensor<?x?xf32> 2356 return %1 : index 2357} 2358 2359// ----- 2360 2361// CHECK-LABEL: func @collapse_expand_fold_to_cast( 2362// CHECK-SAME: %[[t:.*]]: tensor<?xf32> 2363// CHECK: return %[[t]] 2364func.func @collapse_expand_fold_to_cast(%t: tensor<?xf32>, %sz0: index) -> (tensor<?xf32>) 2365{ 2366 %0 = tensor.expand_shape %t [[0, 1]] output_shape [1, %sz0] : tensor<?xf32> into tensor<1x?xf32> 2367 %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32> 2368 return %1 : tensor<?xf32> 2369} 2370 2371// ----- 2372 2373// Chain: NC -> NCnc -> NCnc -> NC 2374// CHECK: func.func @unpack_pack( 2375// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>) 2376// CHECK: return %[[T]] : tensor<128x128xf32> 2377func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> { 2378 %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32> 2379 %packed = tensor.pack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32> 2380 %tensor_empty1 = tensor.empty() : tensor<128x128xf32> 2381 %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor<128x128xf32> 2382 return %unpacked : tensor<128x128xf32> 2383} 2384 2385// ----- 2386 2387// Chain: NC -> NCcn -> NCnc -> NC 2388// CHECK: func.func @unpack_pack( 2389// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>) 2390// CHECK-NOT: return %[[T]] : tensor<128x128xf32> 2391func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> { 2392 %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32> 2393 %packed = tensor.pack %t inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32> 2394 %tensor_empty1 = tensor.empty() : tensor<128x128xf32> 2395 %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor 2396<128x128xf32> 2397 return %unpacked : tensor<128x128xf32> 2398} 2399 2400// ----- 2401 2402// Chain: NC -> CNcn -> NCnc -> NC 2403// CHECK: func.func @unpack_pack( 2404// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>) 2405// CHECK-NOT: return %[[T]] : tensor<128x128xf32> 2406func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> { 2407 %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32> 2408 %packed = tensor.pack %t outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32> 2409 %tensor_empty1 = tensor.empty() : tensor<128x128xf32> 2410 %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor 2411<128x128xf32> 2412 return %unpacked : tensor<128x128xf32> 2413} 2414 2415// ----- 2416 2417// Chain: NC -> NCnc -> NCnc -> NC 2418// CHECK: func.func @unpack_pack( 2419// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>, 2420// CHECK: return %[[T]] : tensor<128x128xf32> 2421func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) -> tensor<128x128xf32> { 2422 %tensor_empty = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32> 2423 
%packed = tensor.pack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x?x?xf32> 2424 %tensor_empty1 = tensor.empty() : tensor<128x128xf32> 2425 %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<16x16x?x?xf32> -> tensor 2426<128x128xf32> 2427 return %unpacked : tensor<128x128xf32> 2428} 2429 2430// ----- 2431 2432// CHECK: func.func @unpack_pack_with_padding_no_canonicalization( 2433// CHECK: tensor.pack 2434// CHECK: tensor.unpack 2435func.func @unpack_pack_with_padding_no_canonicalization(%t: tensor<256x512xbf16>) -> tensor<224x512xbf16> { 2436 %tensor_empty = tensor.empty() : tensor<4x16x64x32xbf16> 2437 %tensor_empty1 = tensor.empty() : tensor<224x512xbf16> 2438 %packed = tensor.pack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty : tensor<256x512xbf16> -> tensor<4x16x64x32xbf16> 2439 %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty1 : tensor<4x16x64x32xbf16> -> tensor<224x512xbf16> 2440 return %unpacked : tensor<224x512xbf16> 2441} 2442 2443// ----- 2444 2445// Chain NCnc -> NC -> NC -> NCnc 2446// CHECK: func.func @pack_unpack( 2447// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>, 2448// CHECK: return %[[T]] : tensor<16x16x?x?xf32> 2449func.func @pack_unpack(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> { 2450 %tensor_empty = tensor.empty() : tensor<128x128xf32> 2451 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32> 2452 %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32> 2453 %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32> 2454 return %packed : tensor<16x16x?x?xf32> 2455} 2456 2457// ----- 2458 2459// Chain NCnc -> NC -> NC -> NCnc 2460// CHECK: func.func @pack_unpack( 2461// CHECK-SAME: %[[T:.+]]: tensor<16x16x8x8xf32> 2462// CHECK: return %[[T]] : tensor<16x16x8x8xf32> 2463func.func @pack_unpack(%t: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> { 2464 %tensor_empty = tensor.empty() : tensor<128x128xf32> 2465 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<16x16x8x8xf32> -> tensor<128x128xf32> 2466 %tensor_empty1 = tensor.empty() : tensor<16x16x8x8xf32> 2467 %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32> 2468 return %packed : tensor<16x16x8x8xf32> 2469} 2470 2471// ----- 2472 2473// CHECK: func.func @pack_unpack_same_tiles( 2474// CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>, 2475// CHECK: return %[[T]] : tensor<?x?x?x?xf32> 2476func.func @pack_unpack_same_tiles(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index, 2477 %tile1: index, %tile2: index) -> tensor<?x?x?x?xf32> { 2478 %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32> 2479 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32> 2480 %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32> 2481 %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> 
tensor<?x?x?x?xf32> 2482 return %packed : tensor<?x?x?x?xf32> 2483} 2484 2485// ----- 2486 2487// CHECK: func.func @pack_unpack_different_tiles( 2488// CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>, 2489// CHECK-NOT: return %[[T]] : tensor<?x?x?x?xf32> 2490func.func @pack_unpack_different_tiles(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index, 2491 %tile1: index, %tile2: index) -> tensor<?x?x?x?xf32> { 2492 %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32> 2493 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32> 2494 %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32> 2495 %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile2, %tile1] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32> 2496 return %packed : tensor<?x?x?x?xf32> 2497} 2498 2499// ----- 2500 2501// CHECK: func.func @pack_unpack_dynamic_with_padding( 2502// CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>, 2503// CHECK-NOT: return %[[T]] : tensor<?x?x?x?xf32> 2504func.func @pack_unpack_dynamic_with_padding(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index, 2505 %tile1: index, %tile2: index, %pad: f32) -> tensor<?x?x?x?xf32> { 2506 %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32> 2507 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32> 2508 %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32> 2509 %packed = tensor.pack %unpacked padding_value(%pad: f32) inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32> 2510 return %packed : tensor<?x?x?x?xf32> 2511} 2512 2513// ----- 2514 2515// CHECK: func.func @pack_outer_dims_unpack_no_outer_dims( 2516// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>, 2517// CHECK: return %[[T]] : tensor<16x16x?x?xf32> 2518func.func @pack_outer_dims_unpack_no_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> { 2519 %tensor_empty = tensor.empty() : tensor<128x128xf32> 2520 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32> 2521 %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32> 2522 %packed = tensor.pack %unpacked outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32> 2523 return %packed : tensor<16x16x?x?xf32> 2524} 2525 2526// ----- 2527 2528// CHECK: func.func @pack_no_outer_dims_unpack_outer_dims( 2529// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>, 2530// CHECK: return %[[T]] : tensor<16x16x?x?xf32> 2531func.func @pack_no_outer_dims_unpack_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> { 2532 %tensor_empty = tensor.empty() : tensor<128x128xf32> 2533 %unpacked = tensor.unpack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32> 2534 %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32> 2535 %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32> 
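  // Note: outer_dims_perm = [0, 1] on the unpack above is the identity permutation,
  // so it does not prevent folding this pack/unpack pair back to %t.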
  return %packed : tensor<16x16x?x?xf32>
}

// -----

// CHECK: func.func @invalid_empty_negative_size
// CHECK: %[[IDX:.*]] = index.constant
// CHECK: %[[T:.*]] = tensor.empty(%[[IDX]]) : tensor<4x5x?xf32>
func.func @invalid_empty_negative_size() -> (tensor<4x5x?xf32>) {
  %c1 = arith.constant 1 : index
  %cn2 = arith.constant 2 : index
  %0 = index.sub %c1, %cn2
  %1 = tensor.empty(%0) : tensor<4x5x?xf32>
  return %1 : tensor<4x5x?xf32>
}

// -----

// Fold DstStyleOp -> tensor.unpack operations.
func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %cst = arith.constant 0.0 : f32
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
  %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %fill : tensor<?x?x16x64xf32> -> tensor<?x?xf32>
  return %unpack : tensor<?x?xf32>
}
// CHECK-LABEL: func @fold_dst_style_ops_into_unpack
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x16x64xf32>
// CHECK-SAME: %[[INIT:.+]]: tensor<?x?xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME: into %[[INIT]]
// CHECK: return %[[UNPACK]]

// -----

// The IR in this test case is invalid. The test only checks that the
// canonicalizer does not crash.

// CHECK-LABEL: func @invalid_slice_ops(
// CHECK: %[[c:.*]] = arith.constant -5 : index
// CHECK: tensor.extract_slice {{.*}}%[[c]]
// CHECK: tensor.insert_slice {{.*}}%[[c]]
func.func @invalid_slice_ops(%t: tensor<?xf32>, %t2: tensor<?xf32>) -> tensor<?xf32> {
  %c = arith.constant -5 : index
  %0 = tensor.extract_slice %t[0][%c][1] : tensor<?xf32> to tensor<?xf32>
  %1 = tensor.insert_slice %0 into %t2[2][%c][1] : tensor<?xf32> into tensor<?xf32>
  return %1 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @generate_negative_size_verifies(
// CHECK: %[[c:.*]] = arith.constant -8 : index
// CHECK: tensor.generate %[[c]]
// CHECK: : tensor<?x8xi32>
func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
  %cst = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %size = affine.max affine_map<(d0) -> (d0 mod 64 - 8)>(%c0)
  %tensor = tensor.generate %size {
  ^bb0(%arg0: index, %arg1: index):
    tensor.yield %cst : i32
  } : tensor<?x8xi32>
  return %tensor : tensor<?x8xi32>
}

// -----

func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> tensor<10x20x4x4xf32> {
  %dim1 = arith.constant 40 : index
  %dim2 = arith.constant 80 : index
  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty : tensor<10x20x4x4xf32> -> tensor<?x?xf32>
  %cast = tensor.cast %unpacked : tensor<?x?xf32> to tensor<40x80xf32>
  %tensor_empty1 = tensor.empty() : tensor<10x20x4x4xf32>
  %packed = tensor.pack %cast inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty1 : tensor<40x80xf32> -> tensor<10x20x4x4xf32>
  return %packed : tensor<10x20x4x4xf32>
}
// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK: return %[[SRC]]

// -----

// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
// CHECK-LABEL:
func @dim_of_reshape( 2621// CHECK-SAME: %[[MEM:[0-9a-z]+]]: tensor<*xf32>, 2622// CHECK-SAME: %[[SHP:[0-9a-z]+]]: tensor<?xindex> 2623// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3 2624// CHECK-NEXT: %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]] 2625// CHECK-NOT: tensor.store 2626// CHECK-NOT: tensor.dim 2627// CHECK-NOT: tensor.reshape 2628// CHECK: return %[[DIM]] : index 2629func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) 2630 -> index { 2631 %c3 = arith.constant 3 : index 2632 %0 = tensor.reshape %arg0(%arg1) 2633 : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32> 2634 // Update the shape to test that the load ends up in the right place. 2635 tensor.insert %c3 into %arg1[%c3] : tensor<?xindex> 2636 %1 = tensor.dim %0, %c3 : tensor<*xf32> 2637 return %1 : index 2638} 2639 2640// ----- 2641 2642// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx] 2643// CHECK-LABEL: func @dim_of_reshape_i32( 2644// CHECK: tensor.extract 2645// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast 2646// CHECK-NOT: tensor.dim 2647// CHECK-NOT: tensor.reshape 2648// CHECK: return %[[CAST]] : index 2649func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>) 2650 -> index { 2651 %c3 = arith.constant 3 : index 2652 %0 = tensor.reshape %arg0(%arg1) 2653 : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32> 2654 %1 = tensor.dim %0, %c3 : tensor<*xf32> 2655 return %1 : index 2656} 2657 2658// ----- 2659 2660// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx] 2661// CHECK-LABEL: func @dim_of_reshape_for( 2662// CHECK: scf.for 2663// CHECK-NEXT: tensor.extract 2664// CHECK-NOT: tensor.dim 2665// CHECK-NOT: tensor.reshape 2666func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index { 2667 %c0 = arith.constant 0 : index 2668 %c1 = arith.constant 1 : index 2669 %c4 = arith.constant 4 : index 2670 2671 %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32> 2672 2673 %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) { 2674 %2 = tensor.dim %0, %arg2 : tensor<*xf32> 2675 %3 = arith.muli %arg3, %2 : index 2676 scf.yield %3 : index 2677 } 2678 return %1 : index 2679} 2680 2681// ----- 2682 2683// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx] 2684// CHECK-LABEL: func @dim_of_reshape_undominated( 2685// CHECK: arith.muli 2686// CHECK-NEXT: tensor.extract 2687// CHECK-NOT: tensor.dim 2688// CHECK-NOT: tensor.reshape 2689func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index { 2690 %c4 = arith.constant 4 : index 2691 %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32> 2692 %0 = arith.muli %arg2, %c4 : index 2693 %dim = tensor.dim %reshape, %0 : tensor<*xf32> 2694 return %dim : index 2695 } 2696 2697// ----- 2698 2699// CHECK-LABEL: @reshape_fold_2d 2700// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32> 2701func.func @reshape_fold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> { 2702 %c0 = arith.constant 0 : index 2703 %c1 = arith.constant 1 : index 2704 %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32> 2705 %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32> 2706 %ds = tensor.from_elements %d0, %d1 : tensor<2xindex> 2707 %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32> 2708 // CHECK: return %[[ARG0]] 2709 return %reshape : tensor<?x?xi32> 2710} 2711 2712// ----- 2713 
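// Unlike @reshape_fold_2d above, the shape operand below lists the source dims in
// swapped order (%d1, %d0), so the reshape is not a guaranteed identity and is kept.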
// CHECK-LABEL: @reshape_nofold_2d
//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xi32>
func.func @reshape_nofold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
  %ds = tensor.from_elements %d1, %d0 : tensor<2xindex>
  // CHECK: tensor.reshape
  %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
  return %reshape : tensor<?x?xi32>
}

// -----

// CHECK-LABEL: @reshape_nofold_2d_ins
func.func @reshape_nofold_2d_ins(%arg0 : tensor<?x?xi32>, %arg1: index, %arg2: index) -> tensor<?x?xi32> {
  %ds = tensor.from_elements %arg1, %arg2 : tensor<2xindex>
  // CHECK: tensor.reshape
  %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
  return %reshape : tensor<?x?xi32>
}

// -----

// CHECK-LABEL: @reshape_fold_3d_cst
//  CHECK-SAME:   %[[ARG0:.+]]: tensor<5x?x?xi32>
func.func @reshape_fold_3d_cst(%arg0 : tensor<5x?x?xi32>) -> tensor<5x?x?xi32> {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %d0 = arith.constant 5 : index
  %d1 = tensor.dim %arg0, %c1 : tensor<5x?x?xi32>
  %d2 = tensor.dim %arg0, %c2 : tensor<5x?x?xi32>
  %ds = tensor.from_elements %d0, %d1, %d2 : tensor<3xindex>
  %reshape = tensor.reshape %arg0(%ds) : (tensor<5x?x?xi32>, tensor<3xindex>) -> tensor<5x?x?xi32>
  // CHECK: return %[[ARG0]]
  return %reshape : tensor<5x?x?xi32>
}

// -----

// Test case: This test fails to fold because the index of tensor.dim is out of bounds.
// CHECK-LABEL: func @dim_out_of_bounds(
//       CHECK:   %[[IDX:.*]] = index.constant 28
//  CHECK-NEXT:   bufferization.alloc_tensor
//  CHECK-NEXT:   %[[DIM:.*]] = tensor.dim %{{.*}}, %[[IDX]]
//  CHECK-NEXT:   memref.alloc
//  CHECK-NEXT:   memref.cast
//  CHECK-NEXT:   affine.vector_load %{{.*}}[{{.*}}, {{.*}}, symbol(%[[DIM]])]
//  CHECK-NEXT:   return
func.func @dim_out_of_bounds() -> vector<7xi32> {
  %c1 = arith.constant 1 : index
  %idx28 = index.constant 28
  %c29 = arith.constant 29 : index
  %3 = bufferization.alloc_tensor(%c29) : tensor<?xi16>
  %dim = tensor.dim %3, %idx28 : tensor<?xi16>
  %alloc_21 = memref.alloc(%c29) : memref<?x26x2xi32>
  %16 = affine.vector_load %alloc_21[%c1, %c1, %dim] : memref<?x26x2xi32>, vector<7xi32>
  return %16 : vector<7xi32>
}

// -----

// CHECK-LABEL: func.func @fold_cast_multiple_results(
//  CHECK-SAME:   %[[ARG1:.*]]: tensor<2x2xf32>,
//  CHECK-SAME:   %[[ARG2:.*]]: tensor<2x2xf32>) -> index {
//       CHECK:   %[[RES:.*]]:2 = test.destination_style_op ins(%[[ARG1]] : tensor<2x2xf32>)
//  CHECK-SAME:     outs(%[[ARG2]] : tensor<2x2xf32>) -> tensor<2x2xf32>, index
//       CHECK:   return %[[RES]]#1 : index
func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
  %cast = tensor.cast %arg0 : tensor<2x2xf32> to tensor<?x2xf32>
  %cast_0 = tensor.cast %arg1 : tensor<2x2xf32> to tensor<?x2xf32>
  %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
  return %0#1 : index
}

// -----

// CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
//  CHECK-SAME:   %[[DEST:.*]]: tensor<1x1x8x1xi32>,
//  CHECK-SAME:   %[[SRC:.*]]: tensor<7x?xi32>,
//  CHECK-SAME:   %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
//       CHECK:   %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32)
//  CHECK-SAME:     inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]]
//  CHECK-SAME:     test_attr
//  CHECK-SAME:     : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
//       CHECK:   return %[[PACK]] : tensor<1x1x8x1xi32>
func.func @fold_cast_pack_dynamic_tile_size(
  %dest: tensor<1x1x8x1xi32>,
  %src: tensor<7x?xi32>,
  %pad: i32) -> tensor<1x1x8x1xi32> {

  %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
  %c8 = arith.constant 8 : index
  %pack = tensor.pack %src padding_value(%pad : i32)
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8, 1]
      into %cast {test_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
  %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
  return %res : tensor<1x1x8x1xi32>
}

// -----

// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size(
//  CHECK-SAME:   %[[SRC:.*]]: tensor<1x1x8x1xi32>,
//  CHECK-SAME:   %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
//       CHECK:   %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {test_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
//       CHECK:   return %[[RES]] : tensor<7x?xi32>
func.func @fold_cast_unpack_dynamic_tile_size(
  %src: tensor<1x1x8x1xi32>,
  %res: tensor<7x?xi32>) -> tensor<7x?xi32> {

  %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
  %c8 = arith.constant 8 : index
  %unpack = tensor.unpack %cast
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8, 1]
      into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
  return %unpack : tensor<7x?xi32>
}

// -----

// CHECK-LABEL: func.func @pack_dont_drop_attributes(
//       CHECK:   tensor.pack {{.*}} {test_attr}
func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
  %c32_i64 = arith.constant 32 : i64
  %cst = arith.constant 0.000000e+00 : f16
  %pack = tensor.pack %arg0 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %arg1 {test_attr} : tensor<?x?x?xf16> -> tensor<128x?x100x16x1xf16>
  return %pack : tensor<128x?x100x16x1xf16>
}

// -----

func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
    -> tensor<10x1x10xf32> {
  %c1 = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  %2 = tensor.cast %1 : tensor<?x?x?xf32> to tensor<10x1x10xf32>
  return %2 : tensor<10x1x10xf32>
}
// CHECK-LABEL: func.func @fold_expand_of_cast
//       CHECK:   %[[RES:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
//       CHECK:   return %[[RES]]

// -----

func.func @sink_expand_of_cast(%arg0 : tensor<?x10xf32>)
    -> tensor<?x?x?xf32> {
  %c1 = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %0 = tensor.cast %arg0 : tensor<?x10xf32> to tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func.func @sink_expand_of_cast
//   CHECK-DAG:   %[[C10:.*]] = arith.constant 10
//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1
//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
//  CHECK-SAME:     output_shape [%[[C10]], %[[C1]], 10]
//       CHECK:   %[[RES:.+]] = tensor.cast %[[EXPAND]]
//       CHECK:   return %[[RES]]

// -----

func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index, %arg2 : index)
    -> tensor<?x?x?xf32> {
  %c10 = arith.constant 10 : index
  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %c10]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func.func @partial_sink_expand_of_cast
//       CHECK:   %[[CAST:.+]] = tensor.cast
//  CHECK-SAME:     tensor<10x10xf32> to tensor<?x10xf32>
//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
//  CHECK-SAME:     output_shape [%{{.*}}, %{{.*}}, 10]
//       CHECK:   %[[RES:.+]] = tensor.cast %[[EXPAND]]
//  CHECK-SAME:     tensor<?x?x10xf32> to tensor<?x?x?xf32>
//       CHECK:   return %[[RES]]