1// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s 2 3// CHECK-LABEL: func @memref_cast( 4func.func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> { 5 %c0 = arith.constant 0 : index 6 %c1 = arith.constant 1 : index 7 %c8 = arith.constant 8 : index 8 %c16 = arith.constant 16 : index 9 %1 = memref.alloc (%b) : memref<?xi8> 10 %2 = memref.view %1[%c0][] : memref<?xi8> to memref<16x16xf32> 11 %3 = memref.cast %2 : memref<16x16xf32> to memref<?x?xf32> 12 13 // CHECK: linalg.matmul ins({{.*}}memref<16x16xf32>, memref<16x16xf32>) outs({{.*}}memref<16x16xf32>) 14 linalg.matmul ins(%3, %3: memref<?x?xf32>, memref<?x?xf32>) 15 outs(%3: memref<?x?xf32>) 16 return %3: memref<?x?xf32> 17} 18 19// ----- 20 21#accesses = [ 22 affine_map<(i) -> (i)> 23] 24 25#trait = { 26 indexing_maps = #accesses, 27 iterator_types = ["parallel"] 28} 29 30func.func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> { 31 // memref<0xf32> is expected to be dce'ed 32 memref.copy %arg0, %arg0 : memref<0xf32> to memref<0xf32> 33 34 // tensor<0xf32> cannot be dce'ed 35 %1 = linalg.generic #trait outs(%arg1 : tensor<0xf32>) { 36 ^bb(%0: f32) : 37 linalg.yield %0 : f32 38 } -> tensor<0xf32> 39 40 return %1: tensor<0xf32> 41} 42// CHECK-LABEL: @dce_zero_memref 43// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<0xf32> 44// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<0xf32> 45// CHECK-NOT: memref.copy 46// CHECK-NEXT: return %[[ARG1]] 47 48// ----- 49 50func.func @dce_self_linalg_copy(%arg0 : memref<?xf32>) { 51 linalg.copy ins(%arg0: memref<?xf32>) outs(%arg0: memref<?xf32>) 52 return 53} 54 55// CHECK-LABEL: @dce_self_linalg_copy 56// CHECK-NOT: copy 57 58// ----- 59 60// CHECK-LABEL: func @tensor.cast( 61func.func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>) 62 -> tensor<3x?xf32> 63{ 64 %ta = tensor.cast %a : tensor<3x4xf32> to tensor<?x?xf32> 65 %tb = tensor.cast %b : tensor<4x?xf32> to tensor<?x?xf32> 
66 %tc = tensor.cast %c : tensor<3x?xf32> to tensor<?x?xf32> 67 68 // CHECK: linalg.matmul ins({{.*}}tensor<3x4xf32>, tensor<4x?xf32>) 69 // CHECK-SAME: outs({{.*}}tensor<3x?xf32>) -> tensor<3x?xf32> 70 %0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>) 71 outs(%tc: tensor<?x?xf32>) -> tensor<?x?xf32> 72 73 %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<3x?xf32> 74 75 return %1: tensor<3x?xf32> 76} 77 78// ----- 79 80// CHECK-LABEL: func @tensor.cast.unranked( 81func.func @tensor.cast.unranked(%a : tensor<*xf32>, %b : tensor<*xf32>, %c : tensor<*xf32>) 82 -> tensor<*xf32> 83{ 84 // CHECK: tensor.cast 85 // CHECK: tensor.cast 86 // CHECK: tensor.cast 87 %ta = tensor.cast %a : tensor<*xf32> to tensor<?x?xf32> 88 %tb = tensor.cast %b : tensor<*xf32> to tensor<?x?xf32> 89 %tc = tensor.cast %c : tensor<*xf32> to tensor<?x?xf32> 90 91 // CHECK: linalg.matmul ins({{.*}}tensor<?x?xf32>, tensor<?x?xf32>) 92 // CHECK-SAME: outs({{.*}}tensor<?x?xf32>) -> tensor<?x?xf32> 93 %0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>) 94 outs(%tc: tensor<?x?xf32>) -> tensor<?x?xf32> 95 96 // CHECK: tensor.cast 97 %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<*xf32> 98 99 return %1: tensor<*xf32> 100} 101 102// ----- 103 104// CHECK-LABEL: func @linalg_effects( 105func.func @linalg_effects( 106 %a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : tensor<?x?xf32>, 107 %d : memref<?x?xf32>, %e : memref<?x?xf32>, %f : memref<?x?xf32>) { 108 // CHECK-NOT: %{{.*}} = linalg.matmul 109 %t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>) 110 outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32> 111 112 // CHECK: linalg.matmul 113 linalg.matmul ins(%d, %e : memref<?x?xf32>, memref<?x?xf32>) 114 outs(%f : memref<?x?xf32>) 115 return 116} 117 118// ----- 119 120#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> 121func.func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>) 122 -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) { 123 %c0 = arith.constant 0 : 
index 124 %c1 = arith.constant 1 : index 125 %c2 = arith.constant 2 : index 126 %0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32> 127 %1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32> 128 %2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32> 129 %3 = tensor.empty(%0, %1, %2) : tensor<?x?x?xf32> 130 %4, %5 = linalg.generic { 131 indexing_maps = [#map, #map, #map, #map], 132 iterator_types = ["parallel", "parallel", "parallel"] 133 } ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) 134 outs(%3, %3 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) { 135 ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32, %arg5 : f32): 136 linalg.yield %arg3, %arg2 : f32, f32 137 } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) 138 return %4, %5 : tensor<?x?x?xf32>, tensor<?x?x?xf32> 139} 140// CHECK-LABEL: func @remove_no_op 141// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32> 142// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32> 143// CHECK: return %[[ARG1]], %[[ARG0]] 144 145// ----- 146 147#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> 148func.func @remove_no_op_mismatched_types(%arg0 : tensor<?x?x?xf32>) 149 -> tensor<1x2x3xf32> { 150 %out = tensor.empty() : tensor<1x2x3xf32> 151 %g = linalg.generic { 152 indexing_maps = [#map, #map], 153 iterator_types = ["parallel", "parallel", "parallel"] 154 } ins(%arg0 : tensor<?x?x?xf32>) 155 outs(%out : tensor<1x2x3xf32>) { 156 ^bb0(%arg2 : f32, %arg3 : f32): 157 linalg.yield %arg2 : f32 158 } -> (tensor<1x2x3xf32>) 159 return %g : tensor<1x2x3xf32> 160} 161// CHECK-LABEL: func @remove_no_op_mismatched_types 162// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32> 163// CHECK: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<1x2x3xf32> 164// CHECK: return %[[CAST]] 165 166// ----- 167 168#map = affine_map<() -> ()> 169func.func @cant_fold_to_tensor_cast(%arg0 : f32) -> tensor<f32> { 170 %out = tensor.empty() : tensor<f32> 171 %g = linalg.generic { 172 indexing_maps = [#map, #map], 173 iterator_types = [] 174 } ins(%arg0 : f32) 175 
outs(%out : tensor<f32>) { 176 ^bb0(%arg2 : f32, %arg3 : f32): 177 linalg.yield %arg2 : f32 178 } -> (tensor<f32>) 179 return %g : tensor<f32> 180} 181// CHECK-LABEL: func @cant_fold_to_tensor_cast 182// CHECK: linalg.generic 183 184// ----- 185 186#map = affine_map<(d0, d1) -> (d0, d1)> 187func.func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> { 188 %c0 = arith.constant 0 : index 189 %c1 = arith.constant 1 : index 190 %cst = arith.constant 1.000000e+00 : f32 191 %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> 192 %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> 193 %2 = tensor.empty(%0, %1) : tensor<?x?xf32> 194 cf.br ^bb1(%cst : f32) 195 196^bb1(%arg1 : f32): 197 %3 = linalg.generic 198 {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} 199 ins(%arg0 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) { 200 ^bb0(%arg2: f32, %arg3 : f32): 201 linalg.yield %arg1 : f32 202 } -> tensor<?x?xf32> 203 return %3 : tensor<?x?xf32> 204} 205// CHECK-LABEL: func @keep_not_noop 206// CHECK: %[[RESULT:.+]] = linalg.generic 207// CHECK: return %[[RESULT]] 208 209// ----- 210 211#map = affine_map<(d0, d1) -> (d0, d1)> 212func.func @keep_not_noop(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) 213 -> (tensor<?x?xf32>, tensor<?x?xf32>) { 214 %c0 = arith.constant 0 : index 215 %c1 = arith.constant 1 : index 216 %cst = arith.constant 1.000000e+00 : f32 217 %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> 218 %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> 219 %2 = tensor.empty(%0, %1) : tensor<?x?xf32> 220 cf.br ^bb1(%cst : f32) 221 222^bb1(%arg2 : f32): 223 %3:2 = linalg.generic 224 {indexing_maps = [#map, #map, #map, #map], 225 iterator_types = ["parallel", "parallel"]} 226 ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) 227 outs(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>) { 228 ^bb0(%arg3: f32, %arg4 : f32, %arg5 : f32, %arg6 : f32): 229 linalg.yield %arg2, %arg4 : f32, f32 230 } -> (tensor<?x?xf32>, tensor<?x?xf32>) 231 return %3#0, %3#1 : tensor<?x?xf32>, 
tensor<?x?xf32> 232} 233// CHECK-LABEL: func @keep_not_noop 234// CHECK: %[[RESULT:.+]]:2 = linalg.generic 235// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 236 237// ----- 238 239#accesses = [ 240 affine_map<(i, j) -> (i, j)> 241] 242 243#trait = { 244 indexing_maps = #accesses, 245 iterator_types = ["parallel", "parallel"] 246} 247 248// CHECK-LABEL: func @dead_linalg_tensor 249// CHECK-NOT: linalg.fill 250// CHECK-NOT: linalg.matmul 251// CHECK-NOT: linalg.generic 252// CHECK-NOT: tensor.pad 253// CHECK: return 254func.func @dead_linalg_tensor(%arg0 : tensor<7x7xi32>, %arg1 : tensor<7x7xf32>, 255 %arg2: tensor<?x?xf32>, %high : index) { 256 %c0_i32 = arith.constant 0 : i32 257 %c0 = arith.constant 0 : index 258 %cst = arith.constant 0.000000e+00 : f32 259 %0 = linalg.fill ins(%c0_i32 : i32) outs(%arg0 : tensor<7x7xi32>) -> tensor<7x7xi32> 260 %1 = linalg.matmul ins(%arg1, %arg1: tensor<7x7xf32>, tensor<7x7xf32>) 261 outs(%arg1: tensor<7x7xf32>) -> tensor<7x7xf32> 262 %2 = linalg.generic #trait outs(%arg0 : tensor<7x7xi32>) { 263 ^bb(%3: i32) : 264 linalg.yield %3 : i32 265 } -> tensor<7x7xi32> 266 %3 = tensor.pad %arg2 low[%c0, %c0] high[%high, %high] { 267 ^bb0(%arg9: index, %arg10: index): 268 tensor.yield %cst : f32 269 } : tensor<?x?xf32> to tensor<2x4xf32> 270 return 271} 272 273// ----- 274 275func.func @propagate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index, 276 %arg3 : index) -> tensor<?x?xf32> { 277 %c0 = arith.constant 0 : index 278 %c1 = arith.constant 1 : index 279 %c21 = arith.constant 21 : index 280 %c42 = arith.constant 42 : index 281 %0 = tensor.empty(%c21, %c42) : tensor<?x?xf32> 282 %1 = linalg.fill ins(%arg1 : f32) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32> 283 %2 = tensor.dim %arg0, %c0 : tensor<?x?xf32> 284 %3 = tensor.dim %arg0, %c1 : tensor<?x?xf32> 285 %4 = tensor.insert_slice %arg0 into %1[%arg2, %arg3] [%2, %3] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32> 286 return %4 : tensor<?x?xf32> 287} 288// CHECK-LABEL: func 
@propagate_casts 289// CHECK: %[[INIT:.+]] = tensor.empty 290// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]] 291// CHECK: %[[INSERTED:.+]] = tensor.insert_slice %{{.+}} into %[[FILL]] 292// CHECK: %[[RESULT:.+]] = tensor.cast %[[INSERTED]] 293// CHECK: return %[[RESULT]] 294 295// ----- 296 297// CHECK-LABEL: @self_copy 298func.func @self_copy(%arg0 : memref<2x3x?x4xf32>) { 299 300// CHECK-NOT: memref.copy 301 memref.copy %arg0, %arg0 : memref<2x3x?x4xf32> to memref<2x3x?x4xf32> 302 303// CHECK: return 304 return 305} 306 307// ----- 308// CHECK-LABEL: func @fold_fill_reshape() 309func.func @fold_fill_reshape() -> tensor<6x4xf32> { 310 %zero = arith.constant 0.0 : f32 311 %empty = tensor.empty() : tensor<1x2x3x4xf32> 312 // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape 313 // CHECK-NEXT: %[[FILL:.+]] = linalg.fill ins(%cst : f32) 314 // CHECK-SAME: outs(%[[COLLAPSE]] : tensor<6x4xf32>) 315 %fill = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> 316 %reshape = tensor.collapse_shape %fill [[0, 1, 2], [3]] 317 : tensor<1x2x3x4xf32> into tensor<6x4xf32> 318 // CHECK: return %[[FILL]] : tensor<6x4xf32> 319 return %reshape : tensor<6x4xf32> 320} 321 322// ----- 323 324// CHECK: func @fold_fill_reshape_dynamic 325// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?x?xf32> 326func.func @fold_fill_reshape_dynamic(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32> { 327 %zero = arith.constant 0.0 : f32 328 // CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] 329 %0 = linalg.fill ins(%zero : f32) outs(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> 330 // CHECK: %[[RESULT:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[RESHAPE]] 331 %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4]] 332 : tensor<?x?x?x?x?xf32> into tensor<?x?xf32> 333 // CHECK: return %[[RESULT]] 334 return %1 : tensor<?x?xf32> 335} 336 337// ----- 338// CHECK: func @fold_fill_extract 339// CHECK-SAME: %[[ARG0:.+]]: i1 340func.func 
@fold_fill_extract(%arg0 : i1) -> i1 { 341 %c0 = arith.constant 0 : index 342 %c1 = arith.constant 1 : index 343 344 %empty_dynamic = tensor.empty(%c1) : tensor<1x2x3x?xi1> 345 %filled = linalg.fill ins(%arg0 : i1) outs(%empty_dynamic : tensor<1x2x3x?xi1>) -> tensor<1x2x3x?xi1> 346 347 %extracted = tensor.extract %filled[%c0, %c0, %c0, %c0] : tensor<1x2x3x?xi1> 348 349 // CHECK: return %[[ARG0]] 350 return %extracted : i1 351} 352 353// ----- 354 355func.func @fill_pack() -> tensor<24x32x16x16xf32> { 356 %dest = tensor.empty() : tensor<384x512xf32> 357 %cst = arith.constant 0.000000e+00 : f32 358 %0 = tensor.empty() : tensor<24x32x16x16xf32> 359 %1 = linalg.fill ins(%cst : f32) outs(%dest : tensor<384x512xf32>) -> tensor<384x512xf32> 360 %pack = tensor.pack %1 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<384x512xf32> -> tensor<24x32x16x16xf32> 361 return %pack : tensor<24x32x16x16xf32> 362} 363// CHECK-LABEL: func.func @fill_pack 364// CHECK: %[[PACKED_EMPTY:.+]] = tensor.empty() : tensor<24x32x16x16xf32> 365// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]] 366// CHECK: return %[[FILL]] 367 368// ----- 369 370func.func @fill_pack_general() -> tensor<1x1x8x4x4x8xi32>{ 371 %c0_i32 = arith.constant 0 : i32 372 %alloc = memref.alloc() : memref<1x1x8x4x4x8xi32> 373 %9 = tensor.empty() : tensor<1x1x16x64xi32> 374 %extracted_slice_15 = tensor.extract_slice %9[0, 0, 0, 0] [1, 1, 16, 64] [1, 1, 1, 1] : tensor<1x1x16x64xi32> to tensor<1x1x16x64xi32> 375 %16 = linalg.fill ins(%c0_i32 : i32) outs(%extracted_slice_15 : tensor<1x1x16x64xi32>) -> tensor<1x1x16x64xi32> 376 %0 = bufferization.to_tensor %alloc restrict writable : memref<1x1x8x4x4x8xi32> to tensor<1x1x8x4x4x8xi32> 377 %pack_18 = tensor.pack %16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %0 : tensor<1x1x16x64xi32> -> tensor<1x1x8x4x4x8xi32> 378 return %pack_18 : tensor<1x1x8x4x4x8xi32> 379} 380 381// CHECK-LABEL: func.func 
@fill_pack_general 382// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1x1x8x4x4x8xi32> 383// CHECK: %[[TENSOR:.+]] = bufferization.to_tensor %[[ALLOC]] 384// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[TENSOR]] 385// CHECK: return %[[FILL]] 386 387// ----- 388 389#map = affine_map<()[s0] -> (s0 ceildiv 16)> 390func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> { 391 %cst = arith.constant 0.000000e+00 : f32 392 %c0 = arith.constant 0 : index 393 %c1 = arith.constant 1 : index 394 %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> 395 %dim = tensor.dim %0, %c0 : tensor<?x?xf32> 396 %dim_0 = tensor.dim %0, %c1 : tensor<?x?xf32> 397 %1 = affine.apply #map()[%dim] 398 %2 = affine.apply #map()[%dim_0] 399 %3 = tensor.empty(%1, %2) : tensor<?x?x16x16xf32> 400 %pack = tensor.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %3 : tensor<?x?xf32> -> tensor<?x?x16x16xf32> 401 return %pack : tensor<?x?x16x16xf32> 402} 403// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)> 404// CHECK: func.func @dynamic_fill_pack 405// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] 406// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index 407// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index 408// CHECK: %[[D0:.+]] = tensor.dim %[[DEST]], %[[C0]] 409// CHECK: %[[D1:.+]] = tensor.dim %[[DEST]], %[[C1]] 410// CHECK: %[[PACKED_D0:.+]] = affine.apply #[[MAP]]()[%[[D0]]] 411// CHECK: %[[PACKED_D1:.+]] = affine.apply #[[MAP]]()[%[[D1]]] 412// CHECK: %[[PACKED_EMPTY:.+]] = tensor.empty(%[[PACKED_D0]], %[[PACKED_D1]]) : tensor<?x?x16x16xf32> 413// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]] 414// CHECK: return %[[FILL]] 415 416// ----- 417 418// CHECK: func @fold_self_copy 419func.func @fold_self_copy(%0 : memref<4x16xf32>) { 420// CHECK-NEXT: return 421 linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, 422 affine_map<(d0, d1) -> (d0, d1)>], 423 
iterator_types = ["parallel", "parallel"]} 424 ins(%0 : memref<4x16xf32>) 425 outs(%0 : memref<4x16xf32>) { 426 ^bb0(%arg4: f32, %arg5: f32): 427 linalg.yield %arg4 : f32 428 } 429 return 430} 431 432// ----- 433 434// CHECK-LABEL: func @fold_static_pad_fill 435// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 436// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<412x276xf32> 437// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]]{{.*}}outs(%[[INIT]] 438// CHECK: return %[[FILL]] 439func.func @fold_static_pad_fill() -> tensor<412x276xf32> { 440 %f0 = arith.constant 0.0 : f32 441 %empty = tensor.empty() : tensor<400x273xf32> 442 %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<400x273xf32>) -> tensor<400x273xf32> 443 %pad = tensor.pad %fill low[4, 1] high[8, 2] { 444 ^bb0(%arg1: index, %arg2: index): 445 tensor.yield %f0 : f32 446 } : tensor<400x273xf32> to tensor<412x276xf32> 447 return %pad : tensor<412x276xf32> 448} 449 450// ----- 451 452// CHECK: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 + 9)> 453// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 + 10)> 454// CHECK: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 + 23)> 455// CHECK: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 32)> 456 457// CHECK: func @fold_dynamic_pad_fill 458// CHECK-SAME: %[[SRC:.+]]: tensor<8x?x16x32xf32>, %[[LOW0:.+]]: index, %[[LOW3:.+]]: index, %[[HIGH2:.+]]: index, %[[HIGH3:.+]]: index 459 460// CHECK-DAG: %[[I1:.+]] = arith.constant 1 : index 461// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 462// CHECK: %[[S0:.+]] = affine.apply #[[MAP0]]()[%[[LOW0]]] 463// CHECK: %[[DIM1:.+]] = tensor.dim %[[SRC]], %[[I1]] : tensor<8x?x16x32xf32> 464// CHECK: %[[S1:.+]] = affine.apply #[[MAP1]]()[%[[DIM1]]] 465// CHECK: %[[S2:.+]] = affine.apply #[[MAP2]]()[%[[HIGH2]]] 466// CHECK: %[[S3:.+]] = affine.apply #[[MAP3]]()[%[[LOW3]], %[[HIGH3]]] 467// CHECK: %[[INIT:.+]] = tensor.empty(%[[S0]], %[[S1]], %[[S2]], %[[S3]]) : tensor<?x?x?x?xf32> 468// CHECK: %[[FILL:.+]] = linalg.fill 
ins(%[[F0]]{{.*}}outs(%[[INIT]] 469// CHECK: return %[[FILL]] 470func.func @fold_dynamic_pad_fill(%empty: tensor<8x?x16x32xf32>, %low0: index, %low3: index, %high2: index, %high3: index) -> tensor<?x?x?x?xf32> { 471 %f0 = arith.constant 0.0 : f32 472 %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x?x16x32xf32>) -> tensor<8x?x16x32xf32> 473 %pad = tensor.pad %fill low[%low0, 8, 7, %low3] high[1, 2, %high2, %high3] { 474 ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index): 475 tensor.yield %f0 : f32 476 } : tensor<8x?x16x32xf32> to tensor<?x?x?x?xf32> 477 return %pad : tensor<?x?x?x?xf32> 478} 479 480// ----- 481 482// CHECK-LABEL: func @no_fold_pad_fill_value_mismatch 483func.func @no_fold_pad_fill_value_mismatch() -> tensor<412x276xf32> { 484 %f0 = arith.constant 0.0 : f32 485 %f1 = arith.constant 1.0 : f32 486 %empty = tensor.empty() : tensor<400x273xf32> 487 %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<400x273xf32>) -> tensor<400x273xf32> 488 // CHECK: tensor.pad 489 %pad = tensor.pad %fill low[4, 1] high[8, 2] { 490 ^bb0(%arg1: index, %arg2: index): 491 tensor.yield %f1 : f32 492 } : tensor<400x273xf32> to tensor<412x276xf32> 493 return %pad : tensor<412x276xf32> 494} 495 496// ----- 497 498// Tests below verify whether static information is propagated through all the operands of generic op. 499// 1. If one of the inputs of generic op has static info and it has no cast source. 500// 2. If one of the inputs of generic op has static info and it is coming from tensor.cast operation. 501// 3. If one of the outputs of generic op has static info and it is coming from tensor.cast operation. 
502#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> 503// CHECK-LABEL: func @static_input_without_cast 504// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> { 505func.func @static_input_without_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> { 506 %c0 = arith.constant 0 : index 507 %c1 = arith.constant 1 : index 508 %c2 = arith.constant 2 : index 509 %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32> 510 %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32> 511 %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32> 512 %3 = tensor.empty(%0, %1, %2) : tensor<?x?x?xf32> 513 %4 = linalg.generic { 514 indexing_maps = [#map, #map, #map], 515 iterator_types = ["parallel", "parallel", "parallel"] 516 } ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<?x?x?xf32>) 517 outs(%3 : tensor<?x?x?xf32>) { 518 ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32): 519 %9 = arith.addf %arg2, %arg3 : f32 520 linalg.yield %9 : f32 521 } -> (tensor<?x?x?xf32>) 522 %5 = tensor.cast %4 : tensor<?x?x?xf32> to tensor<2x3x4xf32> 523 return %5 : tensor<2x3x4xf32> 524 // CHECK: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32> 525 // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic 526 // CHECK-SAME: ins(%[[ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>) 527 // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>) 528} 529 530// ----- 531 532#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> 533// CHECK-LABEL: func @static_input_with_cast 534// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> { 535func.func @static_input_with_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> { 536 %c0 = arith.constant 0 : index 537 %c1 = arith.constant 1 : index 538 %c2 = arith.constant 2 : index 539 %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32> 540 %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32> 541 %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32> 
542 %3 = tensor.empty(%0, %1, %2) : tensor<?x?x?xf32> 543 %4 = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x?x?xf32> 544 %5 = linalg.generic { 545 indexing_maps = [#map, #map, #map], 546 iterator_types = ["parallel", "parallel", "parallel"] 547 } ins(%arg0, %4 : tensor<2x3x4xf32>, tensor<2x?x?xf32>) 548 outs(%3 : tensor<?x?x?xf32>) { 549 ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32): 550 %9 = arith.addf %arg2, %arg3 : f32 551 linalg.yield %9 : f32 552 } -> (tensor<?x?x?xf32>) 553 %6 = tensor.cast %5 : tensor<?x?x?xf32> to tensor<2x3x4xf32> 554 return %6: tensor<2x3x4xf32> 555 // CHECK: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32> 556 // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic 557 // CHECK-SAME: ins(%[[ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>) 558 // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>) 559} 560 561// ----- 562 563#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> 564// CHECK-LABEL: func @static_output_with_cast 565// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>, %[[ARG2:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { 566func.func @static_output_with_cast(%arg0 : tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { 567 %c0 = arith.constant 0 : index 568 %c1 = arith.constant 1 : index 569 %c2 = arith.constant 2 : index 570 %0 = tensor.dim %arg2, %c0 : tensor<2x3x4xf32> 571 %1 = tensor.dim %arg2, %c1 : tensor<2x3x4xf32> 572 %2 = tensor.dim %arg2, %c2 : tensor<2x3x4xf32> 573 %3 = tensor.empty(%0, %1, %2) : tensor<?x?x?xf32> 574 %4 = tensor.cast %3 : tensor<?x?x?xf32> to tensor<2x3x4xf32> 575 %5 = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x?x?xf32> 576 %6 = linalg.generic { 577 indexing_maps = [#map, #map, #map], 578 iterator_types = ["parallel", "parallel", "parallel"] 579 } ins(%arg0, %5 : tensor<?x?x?xf32>, tensor<2x?x?xf32>) 580 outs(%4 : tensor<2x3x4xf32>) { 581 ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): 582 
%9 = arith.addf %arg3, %arg4 : f32 583 linalg.yield %9 : f32 584 } -> (tensor<2x3x4xf32>) 585 return %6: tensor<2x3x4xf32> 586 // CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32> 587 // CHECK-NEXT: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32> 588 // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic 589 // CHECK-SAME: ins(%[[CAST_ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>) 590 // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>) 591} 592 593// ----- 594 595// This test checks the folding of tensor.cast operation when the source value of cast 596// has more static information than the destination value. 597#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> 598// CHECK-LABEL: func @cast_source 599// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { 600func.func @cast_source(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { 601 %c0 = arith.constant 0 : index 602 %c1 = arith.constant 1 : index 603 %c2 = arith.constant 2 : index 604 %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32> 605 %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32> 606 %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32> 607 %3 = tensor.empty(%0, %1, %2) : tensor<?x?x?xf32> 608 %4 = tensor.cast %arg0 : tensor<2x3x4xf32> to tensor<2x?x?xf32> 609 %5 = tensor.cast %arg1 : tensor<2x3x4xf32> to tensor<2x?x?xf32> 610 %6 = linalg.generic { 611 indexing_maps = [#map, #map, #map], 612 iterator_types = ["parallel", "parallel", "parallel"] 613 } ins(%4, %5 : tensor<2x?x?xf32>, tensor<2x?x?xf32>) 614 outs(%3 : tensor<?x?x?xf32>) { 615 ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32): 616 %9 = arith.addf %arg2, %arg3 : f32 617 linalg.yield %9 : f32 618 } -> (tensor<?x?x?xf32>) 619 %7 = tensor.cast %6 : tensor<?x?x?xf32> to tensor<2x3x4xf32> 620 return %7: tensor<2x3x4xf32> 621 // CHECK: %[[GENERIC_OP:.*]] = linalg.generic 622 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : 
tensor<2x3x4xf32>, tensor<2x3x4xf32>) 623 // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>) 624} 625 626// ----- 627 628#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> 629// CHECK-LABEL: func @cast_dest 630// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<1x?x?xf32>, 631func.func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2: index, %arg3: index, %arg4: index) -> tensor<?x?x?xf32> { 632 %0 = tensor.empty(%arg2, %arg3, %arg4) : tensor<?x?x?xf32> 633 %1 = tensor.cast %arg1 : tensor<1x?x?xf32> to tensor<?x?x?xf32> 634 %2 = linalg.generic { 635 indexing_maps = [#map, #map, #map], 636 iterator_types = ["parallel", "parallel", "parallel"] 637 } ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<1x?x?xf32>) 638 outs(%0 : tensor<?x?x?xf32>) { 639 ^bb0(%arg5: f32, %arg6: f32, %arg7: f32): 640 %3 = arith.subf %arg5, %arg6 : f32 641 linalg.yield %3 : f32 642 } -> tensor<?x?x?xf32> 643 return %2 : tensor<?x?x?xf32> 644// CHECK: %[[GENERIC_OP:.*]] = linalg.generic 645// CHECK-SAME: ins(%{{.*}}, %[[ARG1]] : tensor<1x?x?xf32>, tensor<1x?x?xf32>) 646// CHECK-SAME: outs(%{{.*}} : tensor<1x?x?xf32>) 647// CHECK: tensor.cast %[[GENERIC_OP]] : tensor<1x?x?xf32> to tensor<?x?x?xf32> 648} 649 650// ----- 651 652// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 1)> 653// CHECK-LABEL: func @insert_pad_into_fill 654// CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x?xf32>, %[[LOW0:.+]]: index, %[[LOW1:.+]]: index, %{{.+}}: index, %{{.+}}: index) 655// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index 656// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index 657// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index 658// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 659// CHECK: %[[INIT:.+]] = tensor.empty() 660// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]]{{.*}}outs(%[[INIT]] 661// CHECK: %[[OFFSET1:.+]] = affine.apply #[[$MAP]]()[%[[LOW1]]] 662// CHECK: %[[D0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x?xf32> 663// CHECK: %[[D1:.+]] = tensor.dim 
%[[INPUT]], %[[C1]] : tensor<?x?x?xf32> 664// CHECK: %[[D2:.+]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<?x?x?xf32> 665// CHECK: tensor.insert_slice %[[INPUT]] into %[[FILL]][%[[LOW0]], %[[OFFSET1]], 2] [%[[D0]], %[[D1]], %[[D2]]] [1, 1, 1] 666func.func @insert_pad_into_fill(%input: tensor<?x?x?xf32>, %low0: index, %low1: index, %high1: index, %high2: index) -> tensor<8x384x384xf32> { 667 %f0 = arith.constant 0.0 : f32 668 %c0 = arith.constant 0 : index 669 %pad = tensor.pad %input low[%low0, %low1, %c0] high[%c0, %high1, %high2] { 670 ^bb0(%arg3: index, %arg4: index, %arg5: index): 671 tensor.yield %f0 : f32 672 } : tensor<?x?x?xf32> to tensor<8x128x128xf32> 673 %empty = tensor.empty() : tensor<8x384x384xf32> 674 %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> 675 %0 = tensor.insert_slice %pad into %fill[0, 1, 2] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> 676 return %0: tensor<8x384x384xf32> 677} 678 679// ----- 680 681// CHECK-LABEL: func @multi_insert_pad_into_fill 682// CHECK-SAME: (%[[INPUT:.+]]: tensor<7x123x124xf32>, %[[A:.+]]: tensor<8x128x128xf32>, %[[OFFSET:.+]]: index) 683// CHECK: %[[FILL:.+]] = linalg.fill 684// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[A]] into %[[FILL]][%[[OFFSET]], 0, 0] [8, 128, 128] [1, 1, 1] 685// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[A]] into %[[INSERT0]][0, 128, %[[OFFSET]]] [8, 128, 128] [1, 1, 1] 686// CHECK: tensor.insert_slice %[[INPUT]] into %[[INSERT1]][1, 2, 256] [7, 123, 124] [1, 1, 1] 687func.func @multi_insert_pad_into_fill(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> { 688 %f0 = arith.constant 0.0 : f32 689 %c0 = arith.constant 0 : index 690 %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] { 691 ^bb0(%arg3: index, %arg4: index, %arg5: index): 692 tensor.yield %f0 : f32 693 } : tensor<7x123x124xf32> to tensor<8x128x128xf32> 694 %empty = tensor.empty() : 
tensor<8x384x384xf32> 695 %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> 696 %0 = tensor.insert_slice %a into %fill[%offset, 0, 0] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> 697 %1 = tensor.insert_slice %a into %0 [0, 128, %offset][8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> 698 %2 = tensor.insert_slice %pad into %1 [0, 0, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> 699 return %2: tensor<8x384x384xf32> 700} 701 702// ----- 703 704// CHECK-LABEL: func @multi_insert_pad_into_fill_overlap 705func.func @multi_insert_pad_into_fill_overlap(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> { 706 %f0 = arith.constant 0.0 : f32 707 %c0 = arith.constant 0 : index 708 // CHECK: tensor.pad 709 %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] { 710 ^bb0(%arg3: index, %arg4: index, %arg5: index): 711 tensor.yield %f0 : f32 712 } : tensor<7x123x124xf32> to tensor<8x128x128xf32> 713 %empty = tensor.empty() : tensor<8x384x384xf32> 714 %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> 715 %0 = tensor.insert_slice %a into %fill[%offset, 0, 0] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> 716 %1 = tensor.insert_slice %a into %0 [0, 0, 129] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> 717 // Range overlap with %1 at dim#3 718 %2 = tensor.insert_slice %pad into %1 [0, 0, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> 719 return %2: tensor<8x384x384xf32> 720} 721 722// ----- 723 724// CHECK-LABEL: func @multi_insert_pad_into_fill_overlap 725func.func @multi_insert_pad_into_fill_overlap(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> { 726 %f0 = arith.constant 0.0 : f32 727 %c0 = arith.constant 0 : index 
 728 // CHECK: tensor.pad 729 %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] { 730 ^bb0(%arg3: index, %arg4: index, %arg5: index): 731 tensor.yield %f0 : f32 732 } : tensor<7x123x124xf32> to tensor<8x128x128xf32> 733 %empty = tensor.empty() : tensor<8x384x384xf32> 734 %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> 735 %0 = tensor.insert_slice %a into %fill[0, 0, %offset] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> 736 %1 = tensor.insert_slice %a into %0 [0, 128, 255] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> 737 // Range overlap with %0 at dim#3 738 %2 = tensor.insert_slice %pad into %1 [0, 0, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> 739 return %2: tensor<8x384x384xf32> 740} 741 742// ----- 743 744// CHECK-LABEL: func @multi_insert_pad_into_fill 745func.func @multi_insert_pad_into_fill(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> { 746 %f0 = arith.constant 0.0 : f32 747 %c0 = arith.constant 0 : index 748 // CHECK-NOT: tensor.pad 749 %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] { 750 ^bb0(%arg3: index, %arg4: index, %arg5: index): 751 tensor.yield %f0 : f32 752 } : tensor<7x123x124xf32> to tensor<8x128x128xf32> 753 %empty = tensor.empty() : tensor<8x384x384xf32> 754 %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> 755 // Overlap between %0 and %1 is fine, but overlap with %2 is not. 
756 // CHECK-COUNT-3: tensor.insert_slice 757 %0 = tensor.insert_slice %a into %fill[0, 0, %offset] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> 758 %1 = tensor.insert_slice %a into %0 [0, 1, %offset] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> 759 %2 = tensor.insert_slice %pad into %1 [0, 256, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> 760 return %2: tensor<8x384x384xf32> 761} 762 763// ----- 764 765// CHECK-LABEL: func @multi_insert_pad_into_fill_mismatch 766func.func @multi_insert_pad_into_fill_mismatch(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> { 767 %f0 = arith.constant 0.0 : f32 768 %f1 = arith.constant 1.0 : f32 769 %c0 = arith.constant 0 : index 770 // CHECK: tensor.pad 771 %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] { 772 ^bb0(%arg3: index, %arg4: index, %arg5: index): 773 tensor.yield %f0 : f32 774 } : tensor<7x123x124xf32> to tensor<8x128x128xf32> 775 %empty = tensor.empty() : tensor<8x384x384xf32> 776 // Different filling value than padding value. 
777 %fill = linalg.fill ins(%f1 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> 778 %0 = tensor.insert_slice %a into %fill[%offset, 0, 0] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> 779 %1 = tensor.insert_slice %a into %0 [0, 128, %offset][8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> 780 %2 = tensor.insert_slice %pad into %1 [0, 0, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> 781 return %2: tensor<8x384x384xf32> 782} 783 784// ----- 785 786func.func @fold_linalgop_with_cast_consumer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, 787 %arg2 : tensor<?x?xf32>) -> (tensor<4x8xf32>, tensor<?x?xf32>) { 788 %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) 789 outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> 790 %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32> 791 return %1, %0 : tensor<4x8xf32>, tensor<?x?xf32> 792} 793// CHECK: func @fold_linalgop_with_cast_consumer( 794// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32> 795// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32> 796// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>) 797// CHECK-DAG: %[[LHS_CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?xf32> to tensor<4x?xf32> 798// CHECK-DAG: %[[RHS_CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?x?xf32> to tensor<?x8xf32> 799// CHECK-DAG: %[[OUT_CAST:.+]] = tensor.cast %[[ARG2]] : tensor<?x?xf32> to tensor<4x8xf32> 800// CHECK: %[[MATMUL:.+]] = linalg.matmul 801// CHECK-SAME: ins(%[[LHS_CAST]], %[[RHS_CAST]] : 802// CHECK-SAME: outs(%[[OUT_CAST]] : 803// CHECK: %[[RESULT_CAST:.+]] = tensor.cast %[[MATMUL]] 804// CHECK: return %[[MATMUL]], %[[RESULT_CAST]] 805 806// ----- 807 808func.func private @some_use(%0 : tensor<4x8xf32>) 809 810func.func @linalgop_with_cond_cast_consumer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, 811 %arg2 : tensor<?x?xf32>, %arg3 : i1) -> tensor<?x?xf32> { 812 %0 = linalg.matmul ins(%arg0, 
%arg1 : tensor<?x?xf32>, tensor<?x?xf32>) 813 outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> 814 scf.if %arg3 { 815 %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32> 816 func.call @some_use(%1) : (tensor<4x8xf32>) -> () 817 } 818 return %0 : tensor<?x?xf32> 819} 820 821// Check conditionally reachable cast is not folded into producer. 822// CHECK-LABEL: func @linalgop_with_cond_cast_consumer 823// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>, %[[ARG2:.*]]: tensor<?x?xf32>, %[[ARG3:.*]]: i1) 824// CHECK: %[[RES:.*]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) 825// CHECK-SAME: outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32> 826// CHECK: scf.if %[[ARG3]] { 827// CHECK: %[[CAST:.*]] = tensor.cast %[[RES]] : tensor<?x?xf32> to tensor<4x8xf32> 828// CHECK: func.call @some_use(%[[CAST]]) : (tensor<4x8xf32>) -> () 829// CHECK: } 830// CHECK: return %[[RES]] : tensor<?x?xf32> 831 832 833// ----- 834 835func.func @fold_conv_op_with_cast_consumer(%arg0 : tensor<?x?x?x?xf32>, 836 %arg1 : tensor<?x?x?x?xf32>, %arg2 : tensor<?x?x?x?xf32>) -> 837 (tensor<4x8x12x16xf32>, tensor<?x?x?x?xf32>) { 838 %0 = linalg.conv_2d_nchw_fchw ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) 839 outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> 840 %1 = tensor.cast %0 : tensor<?x?x?x?xf32> to tensor<4x8x12x16xf32> 841 return %1, %0 : tensor<4x8x12x16xf32>, tensor<?x?x?x?xf32> 842} 843// CHECK: func @fold_conv_op_with_cast_consumer( 844// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32> 845// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32> 846// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>) 847// CHECK: %[[OUT_CAST:.+]] = tensor.cast %[[ARG2]] : tensor<?x?x?x?xf32> to tensor<4x8x12x16xf32> 848// CHECK: %[[CONV:.+]] = linalg.conv_2d_nchw_fchw 849// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : 850// CHECK-SAME: outs(%[[OUT_CAST]] : 851// CHECK: %[[RESULT_CAST:.+]] = tensor.cast 
%[[CONV]] 852// CHECK: return %[[CONV]], %[[RESULT_CAST]] 853 854// ----- 855 856func.func @fold_multi_use_generic_op_with_consumer(%arg0 : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<2x3x4xf32>) { 857 %c0 = arith.constant 0 : index 858 %c1 = arith.constant 1 : index 859 %c2 = arith.constant 2 : index 860 %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32> 861 %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32> 862 %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32> 863 %empty1 = tensor.empty(%d1, %d2, %d0) : tensor<?x?x?xf32> 864 %empty2 = tensor.empty(%d2, %d1, %d0) : tensor<?x?x?xf32> 865 %0:2 = linalg.generic { 866 iterator_types = ["parallel", "parallel", "parallel"], 867 indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 868 affine_map<(d0, d1, d2) -> (d1, d2, d0)>, 869 affine_map<(d0, d1, d2) -> (d2, d1, d0)>]} 870 ins(%arg0 : tensor<?x?x?xf32>) outs(%empty1, %empty2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) { 871 ^bb0(%b0 : f32, %b1 : f32, %b2 : f32) : 872 linalg.yield %b0, %b0 : f32, f32 873 } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) 874 %1 = tensor.cast %0#1 : tensor<?x?x?xf32> to tensor<2x3x4xf32> 875 return %0#0, %1 : tensor<?x?x?xf32>, tensor<2x3x4xf32> 876} 877// CHECK: func @fold_multi_use_generic_op_with_consumer 878// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32> 879// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<2x3x4xf32> 880// CHECK-DAG: %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<4x3x2xf32> 881// CHECK-DAG: %[[INIT2:.+]] = tensor.empty() : tensor<3x2x4xf32> 882// CHECK: %[[GENERIC:.+]]:2 = linalg.generic 883// CHECK-SAME: ins(%[[CAST]] : 884// CHECK-SAME: outs(%[[INIT2]], %[[INIT1]] : 885// CHECK: %[[RETURN_CAST:.+]] = tensor.cast %[[GENERIC]]#0 : tensor<3x2x4xf32> to tensor<?x?x?xf32> 886// CHECK: return %[[RETURN_CAST]], %[[GENERIC]]#1 887 888// ----- 889 890#map = affine_map<(d0) -> (d0)> 891func.func @identity_buffer(%arg0 : memref<?xf32>, %arg1: memref<?xf32>) { 892 linalg.generic { 893 indexing_maps = [#map, 
#map],
    iterator_types = ["parallel"]
  } ins(%arg0 : memref<?xf32>)
    outs(%arg1 : memref<?xf32>) {
  ^bb0(%arg2 : f32, %arg3 : f32):
    linalg.yield %arg2 : f32
  }
  return
}

// Do not erase ops with buffer semantics.
// CHECK-LABEL: func @identity_buffer
// CHECK-SAME:  (%[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf32>)
// CHECK:       linalg.generic {
// CHECK-SAME:    indexing_maps = [#map, #map],
// CHECK-SAME:    iterator_types = ["parallel"]
// CHECK-SAME:  } ins(%[[ARG1]] : memref<?xf32>)
// CHECK-SAME:    outs(%[[ARG2]] : memref<?xf32>) {

// -----

#map = affine_map<(d0, d1) -> (d1, d0)>
func.func @erase_non_identity_noop(%arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {
    indexing_maps = [#map, #map],
    iterator_types = ["parallel", "parallel"]
  } ins(%arg0 : tensor<?x?xf32>)
    outs(%arg1 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in: f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// This op has tensor semantics and forwards its input unchanged (the same
// non-identity map is used for both input and output), so it is erased.
// CHECK-LABEL: func @erase_non_identity_noop
// CHECK-SAME:  (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>)
// CHECK:       return %[[ARG0]] : tensor<?x?xf32>

// -----

// Just make sure that we don't crash.
935 936// CHECK-LABEL: func @dedeplicate_regression_test 937func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: tensor<4xf32>) { 938 %36 = linalg.generic 939 {indexing_maps = [affine_map<(d0) -> (d0)>, 940 affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], 941 iterator_types = ["parallel"]} 942 ins(%1, %1 : tensor<4xf32>, tensor<4xf32>) 943 outs(%0 : tensor<4xf32>) { 944 ^bb0(%in: f32, %in_24: f32, %out: f32): 945 linalg.yield %in : f32 946 } -> tensor<4xf32> 947 %53 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>], 948 iterator_types = ["parallel"]} 949 outs(%36 : tensor<4xf32>) { 950 ^bb0(%out: f32): 951 linalg.yield %out : f32 952 } -> tensor<4xf32> 953 return 954} 955 956// ----- 957 958// CHECK-LABEL: dead_softmax 959func.func @dead_softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> { 960 %0 = tensor.empty() : tensor<16x64x256xf32> 961 // CHECK-NOT: linalg.softmax 962 %1 = linalg.softmax dimension(1) 963 ins(%arg0 : tensor<16x64x256xf32>) outs(%0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32> 964 return %arg0 : tensor<16x64x256xf32> 965} 966 967// ----- 968 969// CHECK-LABEL: func @canonicalize_dim_of_dest_style_op 970// CHECK: tensor.dim 971// CHECK: tensor.dim 972// CHECK-NOT: tensor.dim 973// CHECK: return 974func.func @canonicalize_dim_of_dest_style_op(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> { 975 %c0 = arith.constant 0 : index 976 %c1 = arith.constant 1 : index 977 %dim0_0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> 978 %dim1_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32> 979 %0 = tensor.empty(%dim0_0, %dim1_0) : tensor<?x?xf32> 980 %1 = linalg.copy ins(%arg0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32> 981 %dim0_1 = tensor.dim %1, %c0 : tensor<?x?xf32> 982 %dim1_1 = tensor.dim %1, %c1 : tensor<?x?xf32> 983 %2 = tensor.empty(%dim0_1, %dim1_1) : tensor<?x?xf32> 984 %3 = linalg.copy ins(%1 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32> 985 return %3: tensor<?x?xf32> 986} 987// ----- 
988 989// CHECK-LABEL: func @canonicalize_fill_to_copy_input( 990// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32> 991// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>) 992// CHECK: %[[ZERO:.+]] = arith.constant 0.0 993// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[ARG1]] : tensor<?x?xf32>) 994func.func @canonicalize_fill_to_copy_input(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> { 995 %c0 = arith.constant 0.0 : f32 996 %fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> 997 %copy = linalg.copy ins(%fill : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> 998 return %copy : tensor<?x?xf32> 999} 1000 1001// ----- 1002 1003// CHECK-LABEL: func @canonicalize_fill_to_copy_dest( 1004// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32> 1005// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>) 1006// CHECK: linalg.copy ins(%[[ARG1]] : tensor<?x?xf32>) outs(%[[ARG0]] : tensor<?x?xf32>) 1007func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> { 1008 %c0 = arith.constant 0.0 : f32 1009 %fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> 1010 %copy = linalg.copy ins(%arg1 : tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32> 1011 return %copy : tensor<?x?xf32> 1012} 1013 1014// ----- 1015 1016// CHECK-LABEL: func @canonicalize_fill_to_transpose_input( 1017// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32> 1018// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>) 1019// CHECK: %[[ZERO:.+]] = arith.constant 0.0 1020// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[ARG1]] : tensor<?x?xf32>) 1021func.func @canonicalize_fill_to_transpose_input(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> { 1022 %c0 = arith.constant 0.0 : f32 1023 %fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> 1024 %transpose = linalg.transpose ins(%fill : 
tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0] 1025 return %transpose : tensor<?x?xf32> 1026} 1027 1028// ----- 1029 1030// CHECK-LABEL: func @broadcast_same_shape( 1031// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<2x3xf32> 1032// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x3xf32>) 1033// CHECK-NOT: linalg.broadcast 1034// CHECK: return %[[ARG0]] : tensor<2x3xf32> 1035func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>) -> tensor<2x3xf32> { 1036 %0 = linalg.broadcast ins(%input: tensor<2x3xf32>) outs(%init: tensor<2x3xf32>) dimensions = [] 1037 return %0 : tensor<2x3xf32> 1038} 1039 1040// ----- 1041 1042func.func @transpose_1d(%input: tensor<16xf32>, 1043 %init: tensor<16xf32>) -> tensor<16xf32> { 1044 %transpose = linalg.transpose 1045 ins(%input:tensor<16xf32>) 1046 outs(%init:tensor<16xf32>) 1047 permutation = [0] 1048 func.return %transpose : tensor<16xf32> 1049} 1050 1051// CHECK-LABEL: func @transpose_1d( 1052// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<16xf32>, 1053// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<16xf32>) 1054// CHECK-NOT: linalg.transpose 1055// CHECK: return %[[INPUT]] : tensor<16xf32> 1056 1057// ----- 1058 1059func.func @transpose_identity_perm(%input: tensor<16x32x64xf32>, 1060 %init: tensor<16x32x64xf32>) -> tensor<16x32x64xf32> { 1061 %transpose = linalg.transpose 1062 ins(%input:tensor<16x32x64xf32>) 1063 outs(%init:tensor<16x32x64xf32>) 1064 permutation = [0, 1, 2] 1065 func.return %transpose : tensor<16x32x64xf32> 1066} 1067 1068// CHECK-LABEL: func @transpose_identity_perm( 1069// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<16x32x64xf32>, 1070// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<16x32x64xf32>) 1071// CHECK-NOT: linalg.transpose 1072// CHECK: return %[[INPUT]] : tensor<16x32x64xf32> 1073 1074// ----- 1075 1076func.func @transpose_transpose_cancel(%input: tensor<5x4x3xf32>, 1077 %init1: tensor<4x3x5xf32>, 1078 %init2: tensor<5x4x3xf32>) -> tensor<5x4x3xf32> { 1079 // 
CHECK-LABEL: @transpose_transpose_cancel 1080 // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<5x4x3xf32> 1081 // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<4x3x5xf32> 1082 // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<5x4x3xf32> 1083 // CHECK-NOT: linalg.transpose 1084 // CHECK: return %[[INPUT]] : tensor<5x4x3xf32> 1085 %transpose1 = linalg.transpose 1086 ins(%input:tensor<5x4x3xf32>) 1087 outs(%init1:tensor<4x3x5xf32>) 1088 permutation = [1, 2, 0] 1089 %transpose2 = linalg.transpose 1090 ins(%transpose1:tensor<4x3x5xf32>) 1091 outs(%init2:tensor<5x4x3xf32>) 1092 permutation = [2, 0, 1] 1093 func.return %transpose2 : tensor<5x4x3xf32> 1094} 1095 1096// ----- 1097 1098func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>, 1099 %init1: tensor<4x3x5xf32>, 1100 %init2: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> { 1101 // CHECK-LABEL: @transpose_transpose_fold 1102 // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<5x4x3xf32> 1103 // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<4x3x5xf32> 1104 // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<3x4x5xf32> 1105 // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<5x4x3xf32>) outs(%[[INIT2]] : tensor<3x4x5xf32>) permutation = [2, 1, 0] 1106 // CHECK-NOT: linalg.transpose 1107 // CHECK: return %[[TRANSPOSE]] : tensor<3x4x5xf32> 1108 %transpose1 = linalg.transpose 1109 ins(%input:tensor<5x4x3xf32>) 1110 outs(%init1:tensor<4x3x5xf32>) 1111 permutation = [1, 2, 0] 1112 %transpose2 = linalg.transpose 1113 ins(%transpose1:tensor<4x3x5xf32>) 1114 outs(%init2:tensor<3x4x5xf32>) 1115 permutation = [1, 0, 2] 1116 func.return %transpose2 : tensor<3x4x5xf32> 1117} 1118 1119// ----- 1120 1121func.func @broadcast_transpose_fold(%input: tensor<2x4x5xf32>, 1122 %init1: tensor<1x2x3x4x5x6xf32>, 1123 %init2: tensor<1x6x2x3x5x4xf32>) -> tensor<1x6x2x3x5x4xf32> { 1124 // CHECK-LABEL: @broadcast_transpose_fold 1125 // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2x4x5xf32> 1126 // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: 
tensor<1x2x3x4x5x6xf32> 1127 // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x6x2x3x5x4xf32> 1128 // CHECK: %[[TMP_INIT:.+]] = tensor.empty() : tensor<2x5x4xf32> 1129 // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<2x4x5xf32>) outs(%[[TMP_INIT]] : tensor<2x5x4xf32>) permutation = [0, 2, 1] 1130 // CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<2x5x4xf32>) outs(%[[INIT2]] : tensor<1x6x2x3x5x4xf32>) dimensions = [0, 3, 1] 1131 // CHECK: return %[[BROADCAST]] : tensor<1x6x2x3x5x4xf32> 1132 %broadcast = linalg.broadcast 1133 ins(%input : tensor<2x4x5xf32>) 1134 outs(%init1 : tensor<1x2x3x4x5x6xf32>) 1135 dimensions = [0, 2, 5] 1136 %transpose = linalg.transpose 1137 ins(%broadcast : tensor<1x2x3x4x5x6xf32>) 1138 outs(%init2 : tensor<1x6x2x3x5x4xf32>) 1139 permutation = [0, 5, 1, 2, 4, 3] 1140 func.return %transpose : tensor<1x6x2x3x5x4xf32> 1141} 1142 1143// ----- 1144 1145func.func @broadcast_transpose_fold_dynamic(%input: tensor<?x?x5xf32>, 1146 %init1: tensor<1x?x3x?x5x6xf32>, 1147 %init2: tensor<1x3x?x6x5x?xf32>) -> tensor<1x3x?x6x5x?xf32> { 1148 // CHECK-LABEL: @broadcast_transpose_fold_dynamic 1149 // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x5xf32> 1150 // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x?x3x?x5x6xf32> 1151 // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x3x?x6x5x?xf32> 1152 // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index 1153 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index 1154 // CHECK: %[[DIM0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x5xf32> 1155 // CHECK: %[[DIM1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<?x?x5xf32> 1156 // CHECK: %[[TMP_INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]]) : tensor<?x5x?xf32> 1157 // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<?x?x5xf32>) outs(%[[TMP_INIT]] : tensor<?x5x?xf32>) permutation = [1, 2, 0] 1158 // CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<?x5x?xf32>) outs(%[[INIT2]] : 
tensor<1x3x?x6x5x?xf32>) dimensions = [0, 1, 3] 1159 // CHECK: return %[[BROADCAST]] : tensor<1x3x?x6x5x?xf32> 1160 %broadcast = linalg.broadcast 1161 ins(%input : tensor<?x?x5xf32>) 1162 outs(%init1 : tensor<1x?x3x?x5x6xf32>) 1163 dimensions = [0, 2, 5] 1164 %transpose = linalg.transpose 1165 ins(%broadcast : tensor<1x?x3x?x5x6xf32>) 1166 outs(%init2 : tensor<1x3x?x6x5x?xf32>) 1167 permutation = [0, 2, 3, 5, 4, 1] 1168 func.return %transpose : tensor<1x3x?x6x5x?xf32> 1169} 1170 1171// ----- 1172 1173func.func @broadcast_transpose_fold_2dim(%input: tensor<2xf32>, 1174 %init1: tensor<2x4xf32>, 1175 %init2: tensor<4x2xf32>) -> tensor<4x2xf32> { 1176 // CHECK-LABEL: @broadcast_transpose_fold_2dim 1177 // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32> 1178 // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32> 1179 // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<4x2xf32> 1180 // CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<4x2xf32>) dimensions = [0] 1181 // CHECK: return %[[BROADCAST]] : tensor<4x2xf32> 1182 %broadcast = linalg.broadcast 1183 ins(%input : tensor<2xf32>) 1184 outs(%init1 : tensor<2x4xf32>) 1185 dimensions = [1] 1186 %transpose = linalg.transpose 1187 ins(%broadcast : tensor<2x4xf32>) 1188 outs(%init2 : tensor<4x2xf32>) 1189 permutation = [1, 0] 1190 func.return %transpose : tensor<4x2xf32> 1191} 1192 1193// ----- 1194 1195func.func @concats_of_fill( 1196 %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) 1197 -> tensor<5x?x?xf32> 1198{ 1199 %cst0 = arith.constant 0.0 : f32 1200 %cst1 = arith.constant 0.0 : f32 1201 %0 = tensor.empty(%arg0, %arg1) : tensor<5x?x?xf32> 1202 %1 = linalg.fill ins(%cst0 : f32) outs(%0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32> 1203 %2 = tensor.empty(%arg2, %arg3) : tensor<5x?x?xf32> 1204 %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32> 1205 %4 = tensor.concat dim(1) %1, %3 : (tensor<5x?x?xf32>, tensor<5x?x?xf32>) -> 
tensor<5x?x?xf32>
  return %4 : tensor<5x?x?xf32>
}
// CHECK:       func @concats_of_fill(
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME:      %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME:      %[[ARG2:[a-zA-Z0-9]+]]: index,
// CHECK-SAME:      %[[ARG3:[a-zA-Z0-9]+]]: index)
// CHECK-DAG:     %[[CST:.+]] = arith.constant 0.0
// CHECK-DAG:     %[[EMPTY0:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
// CHECK-DAG:     %[[EMPTY1:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]])
// CHECK:         %[[CONCAT:.+]] = tensor.concat dim(1) %[[EMPTY0]], %[[EMPTY1]]
// CHECK:         %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[CONCAT]] :
// CHECK:         return %[[FILL]]

// -----

func.func @transpose_buffer(%input: memref<?xf32>,
                            %init: memref<?xf32>) {
  linalg.transpose ins(%input:memref<?xf32>)
                   outs(%init:memref<?xf32>)
                   permutation = [0]
  func.return
}

// CHECK-LABEL: func.func @transpose_buffer(
// CHECK-SAME:      %[[VAL_0:.*]]: memref<?xf32>,
// CHECK-SAME:      %[[VAL_1:.*]]: memref<?xf32>) {
// CHECK:           linalg.transpose ins(%[[VAL_0]] : memref<?xf32>)
// CHECK-SAME:        outs(%[[VAL_1]] : memref<?xf32>) permutation = [0]

// -----

// This test checks that linalg ops have a recursive memory effect; without
// it, the linalg.map below, which has no users, would be DCE'd.
func.func @recursive_effect(%arg : tensor<1xf32>) {
  %init = arith.constant dense<0.0> : tensor<1xf32>
  %mapped = linalg.map ins(%arg:tensor<1xf32>) outs(%init :tensor<1xf32>)
    (%in : f32) {
      vector.print %in : f32
      linalg.yield %in : f32
    }
  func.return
}

// CHECK-LABEL: @recursive_effect
// CHECK:         linalg.map