// RUN: mlir-opt %s --one-shot-bufferize="dialect-filter=tensor,bufferization copy-before-write unknown-type-conversion=identity-layout-map" -cse -split-input-file | FileCheck %s

// CHECK-LABEL: func @dim(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[INDEX:.*]]: index) -> index {
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<*xf32> to memref<*xf32>
// CHECK: %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<*xf32>
// CHECK: return %[[EXTENT]] : index
func.func @dim(%arg0: tensor<*xf32>, %arg1: index) -> index {
  %0 = tensor.dim %arg0, %arg1 : tensor<*xf32>
  return %0 : index
}

// -----

// CHECK-LABEL: func @rank(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> index {
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
// CHECK: %[[EXTENT:.*]] = memref.rank %[[MEMREF]] : memref<*xf32>
func.func @rank(%arg0: tensor<*xf32>) -> index {
  %0 = tensor.rank %arg0 : tensor<*xf32>
  return %0 : index
}

// -----

// CHECK-LABEL: func @tensor.cast(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
// CHECK: %[[CASTED:.*]] = memref.cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[CASTED]]
// CHECK: return %[[RET]] : tensor<2xindex>
func.func @tensor.cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
  %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<2xindex>
  return %0 : tensor<2xindex>
}

// -----

// CHECK-LABEL: func @tensor.cast_from_unranked(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<*xf32> to memref<*xf32>
// CHECK: %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<*xf32> to memref<2xf32, strided<[?], offset: ?>>
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<2xf32, strided<[?], offset: ?>>
// CHECK: return %[[RET]] : tensor<2xf32>
func.func @tensor.cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
  %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<2xf32>
  return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: func @tensor.cast_to_unranked(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> {
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<2xf32> to memref<2xf32>
// CHECK: %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<2xf32> to memref<*xf32>
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<*xf32>
// CHECK: return %[[RET]] : tensor<*xf32>
func.func @tensor.cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
  %0 = tensor.cast %arg0 : tensor<2xf32> to tensor<*xf32>
  return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: func @tensor.empty(
// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<5xf32>
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<5xf32>
// CHECK: return %[[RET]] : tensor<5xf32>
func.func @tensor.empty() -> tensor<5xf32> {
  %0 = tensor.empty() : tensor<5xf32>
  return %0 : tensor<5xf32>
}

// -----

// CHECK-LABEL: func @tensor.extract(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> f32 {
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<?xf32> to memref<?xf32>
// CHECK: %[[RET:.*]] = memref.load %[[MEMREF]][%[[IDX]]] : memref<?xf32>
// CHECK: return %[[RET]] : f32
// CHECK: }
func.func @tensor.extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
  %0 = tensor.extract %arg0[%arg1] : tensor<?xf32>
  return %0 : f32
}

// -----

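// Note: the following tensor.from_elements tests check that the op bufferizes
// to a memref.alloc followed by one store per element, with the result
// materialized back via bufferization.to_tensor.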
// CHECK-LABEL: func @tensor.from_elements_0d(
// CHECK-SAME: %[[ELEM0:.*]]: index) -> tensor<index> {
// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<index>
// CHECK: store %[[ELEM0]], %[[MEMREF]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<index>
func.func @tensor.from_elements_0d(%arg0: index) -> tensor<index> {
  %0 = tensor.from_elements %arg0 : tensor<index>
  return %0 : tensor<index>
}

// -----

// CHECK-LABEL: func @tensor.from_elements_1d(
// CHECK-SAME: %[[ELEM0:.*]]: index,
// CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<2xindex>
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<2xindex>
func.func @tensor.from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex> {
  %0 = tensor.from_elements %arg0, %arg1 : tensor<2xindex>
  return %0 : tensor<2xindex>
}

// -----

// CHECK-LABEL: func @tensor.from_elements_2d(
// CHECK-SAME: %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index)
// CHECK-SAME: -> tensor<3x2xindex> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]]
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<3x2xindex>
func.func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
  %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
      : tensor<3x2xindex>
  return %0 : tensor<3x2xindex>
}

// -----

// CHECK-LABEL: func @tensor.from_elements_3d(
// CHECK-SAME: %[[F0:.*]]: f32

// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0
// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0
// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0
// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0
// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0
// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0
// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01

// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index

// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2x2xf32>

// CHECK: store %[[F0]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: store %[[F1]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C1]]]
// CHECK: store %[[F2]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C0]]]
// CHECK: store %[[F3]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C1]]]
// CHECK: store %[[F4]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C0]]]
// CHECK: store %[[F5]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C1]]]
// CHECK: store %[[F6]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C0]]]
// CHECK: store %[[F7]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C1]]]
// CHECK: store %[[F8]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C0]]]
// CHECK: store %[[F9]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C1]]]
// CHECK: store %[[F10]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C0]]]
// CHECK: store %[[F11]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C1]]]

// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<3x2x2xf32>
func.func @tensor.from_elements_3d(%f0 : f32) -> tensor<3x2x2xf32> {
  %f1 = arith.constant 1.0 : f32
  %f2 = arith.constant 2.0 : f32
  %f3 = arith.constant 3.0 : f32
  %f4 = arith.constant 4.0 : f32
  %f5 = arith.constant 5.0 : f32
  %f6 = arith.constant 6.0 : f32
  %f7 = arith.constant 7.0 : f32
  %f8 = arith.constant 8.0 : f32
  %f9 = arith.constant 9.0 : f32
  %f10 = arith.constant 10.0 : f32
  %f11 = arith.constant 11.0 : f32
  %0 = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
      : tensor<3x2x2xf32>
  return %0 : tensor<3x2x2xf32>
}

// -----

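// Note: tensor.generate bufferizes to a linalg.map over a to_tensor of a new
// allocation; the generator body becomes the map's region, as checked below.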
// CHECK-LABEL: func @tensor.generate(
// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
// CHECK-DAG: %[[ARG_M:.*]] = bufferization.to_memref %[[ARG]] : tensor<*xf32> to memref<*xf32>
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: %[[MAPPED:.*]] = linalg.map
// CHECK: outs(%[[ALLOC_T]] : tensor<?xindex>)
// CHECK: %[[INDEX:.*]] = linalg.index 0 : index
// CHECK: %[[ELEM:.*]] = memref.dim %[[ARG_M]], %[[INDEX]] : memref<*xf32>
// CHECK: linalg.yield %[[ELEM]]
// CHECK: }
// CHECK: return %[[MAPPED]] : tensor<?xindex>
// CHECK: }
func.func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tensor<?xindex> {
  %result = tensor.generate %dynamic_extent {
  ^bb0(%i : index):
    %elem = tensor.dim %arg, %i : tensor<*xf32>
    tensor.yield %elem : index
  } : tensor<?xindex>
  return %result : tensor<?xindex>
}

// -----

// Additional test that checks the logic for intermixed static and dynamic
// extents.
//
// CHECK-LABEL: func @tensor.generate_static_and_dynamic(
// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> {
// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex>
// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: %[[MAPPED:.*]] = linalg.map
// CHECK: outs(%[[ALLOC_T]] : tensor<16x?xindex>)
// CHECK: %[[INDEX0:.*]] = linalg.index 0
// CHECK: %[[INDEX1:.*]] = linalg.index 1
// CHECK: %[[ADD:.*]] = arith.addi %[[INDEX0]], %[[INDEX1]]
// CHECK: linalg.yield %[[ADD]]
// CHECK: }
// CHECK: return %[[MAPPED]] : tensor<16x?xindex>
// CHECK: }
func.func @tensor.generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
  %result = tensor.generate %arg0 {
  ^bb0(%i: index, %j: index):
    %sum = arith.addi %i, %j : index
    tensor.yield %sum : index
  } : tensor<16x?xindex>
  return %result : tensor<16x?xindex>
}

// -----

// CHECK-LABEL: func @tensor.generate_unknown_ops_in_body
func.func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
  // CHECK-NOT: tensor.generate
  %tensor = tensor.generate %arg0 {
  ^bb0(%iv: index):
    // CHECK: test.source
    %0 = "test.source"() : () -> index
    tensor.yield %0 : index
  } : tensor<?xindex>
  return %tensor : tensor<?xindex>
}

// -----

// CHECK-LABEL: func @tensor.extract_slice(
// CHECK-SAME: %[[t1:.*]]: tensor<?x?xf32>, %[[idx1:.*]]: index, %[[idx2:.*]]: index
func.func @tensor.extract_slice(
    %t1: tensor<?x?xf32>, %idx1: index, %idx2: index) -> tensor<?x10xf32> {
  // CHECK: %[[m:.*]] = bufferization.to_memref %[[t1]] : tensor<?x?xf32> to memref<?x?xf32>
  // CHECK: %[[r:.*]] = memref.subview %[[m]][5, %[[idx2]]] [%[[idx1]], 10] [1, 1] : memref<?x?xf32> to memref<?x10xf32, strided<[?, 1], offset: ?>>
  %0 = tensor.extract_slice %t1[5, %idx2][%idx1, 10][1, 1]
      : tensor<?x?xf32> to tensor<?x10xf32>
  // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
  // CHECK: return %[[r_tensor]]
  return %0 : tensor<?x10xf32>
}

// -----

// CHECK-LABEL: func @tensor.extract_slice_rank_reducing(
// CHECK-SAME: %[[t1:.*]]: tensor<?x10x?xf32>, %[[idx1:.*]]: index,
// CHECK-SAME: %[[idx2:.*]]: index
func.func @tensor.extract_slice_rank_reducing(
    %t1: tensor<?x10x?xf32>, %idx1: index, %idx2: index) -> tensor<?x15xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : tensor<?x10x?xf32> to memref<?x10x?xf32>
  // CHECK: %[[r:.*]] = memref.subview %[[m1]][5, %[[idx1]], 10] [%[[idx2]], 1, 15] [1, 1, 1] : memref<?x10x?xf32> to memref<?x15xf32, strided<[?, 1], offset: ?>>
  %0 = tensor.extract_slice %t1[5, %idx1, 10][%idx2, 1, 15][1, 1, 1]
      : tensor<?x10x?xf32> to tensor<?x15xf32>
  // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
  // CHECK: return %[[r_tensor]]
  return %0 : tensor<?x15xf32>
}

// -----

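// Note: with the copy-before-write option from the RUN line, insertion-like
// ops bufferize to a full copy of the destination buffer followed by a copy
// or store into a subview of that copy, as the next tests check.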
// CHECK-LABEL: func @tensor.insert_slice(
// CHECK-SAME: %[[t1:.*]]: tensor<?x?xf32>, %[[t2:.*]]: tensor<?x10xf32>,
// CHECK-SAME: %[[idx1:.*]]: index, %[[idx2:.*]]: index
func.func @tensor.insert_slice(%t1: tensor<?x?xf32>, %t2: tensor<?x10xf32>,
    %idx1: index, %idx2: index) -> tensor<?x?xf32> {
  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : tensor<?x?xf32> to memref<?x?xf32>
  // CHECK-DAG: %[[m2:.*]] = bufferization.to_memref %[[t2]] : tensor<?x10xf32> to memref<?x10xf32>
  // CHECK-DAG: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
  // CHECK-DAG: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim0]], %[[dim1]])
  // CHECK: memref.copy %[[m1]], %[[alloc]]
  // CHECK: %[[subview:.*]] = memref.subview %[[alloc]][%[[idx1]], 5] [%[[idx2]], 10] [1, 1]
  // CHECK: memref.copy %[[m2]], %[[subview]]
  %0 = tensor.insert_slice %t2 into %t1[%idx1, 5][%idx2, 10][1, 1]
      : tensor<?x10xf32> into tensor<?x?xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
  // CHECK: return %[[r]]
  return %0 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @tensor.insert_slice_rank_reducing_1(
func.func @tensor.insert_slice_rank_reducing_1(
    %t1: tensor<?x?xf32>, %f: tensor<f32>, %idx1: index, %idx2: index)
  -> tensor<?x?xf32>
{
  // CHECK: %[[alloc:.*]] = memref.alloc{{.*}} : memref<?x?xf32>
  // CHECK: memref.subview %[[alloc]][%{{.*}}, %{{.*}}] [1, 1] [1, 1] : memref<?x?xf32> to memref<f32, strided<[], offset: ?>>
  // CHECK: memref.copy {{.*}} : memref<f32> to memref<f32, strided<[], offset: ?>>
  %0 = tensor.insert_slice %f into %t1[%idx1, %idx2][1, 1][1, 1]
      : tensor<f32> into tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @tensor.insert_slice_rank_reducing_2(
func.func @tensor.insert_slice_rank_reducing_2(
    %t1: tensor<?x?x?x?x?x?x?xf32>, %t2: tensor<2x1x4x1x1xf32>, %i: index)
  -> tensor<?x?x?x?x?x?x?xf32>
{
  // CHECK: %[[alloc:.*]] = memref.alloc{{.*}} : memref<?x?x?x?x?x?x?xf32>
  // CHECK: memref.subview %[[alloc]][{{.*}}] [1, 2, 1, 4, 1, 1, 1] [1, 1, 1, 1, 1, 1, 1] : memref<?x?x?x?x?x?x?xf32> to memref<2x1x4x1x1xf32, strided<[?, ?, ?, ?, ?], offset: ?>>
  // CHECK: memref.copy {{.*}} : memref<2x1x4x1x1xf32> to memref<2x1x4x1x1xf32, strided<[?, ?, ?, ?, ?], offset: ?>>
  %0 = tensor.insert_slice %t2 into %t1[%i, %i, %i, %i, %i, %i, %i][1, 2, 1, 4, 1, 1, 1][1, 1, 1, 1, 1, 1, 1]
      : tensor<2x1x4x1x1xf32> into tensor<?x?x?x?x?x?x?xf32>
  return %0 : tensor<?x?x?x?x?x?x?xf32>
}

// -----

// CHECK-LABEL: func @tensor.insert(
// CHECK-SAME: %[[t1:.*]]: tensor<5xf32>, %[[idx1:.*]]: index,
// CHECK-SAME: %[[f:.*]]: f32
func.func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5xf32> {
  // CHECK-DAG: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
  // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : tensor<5xf32> to memref<5xf32>
  // CHECK: memref.copy %[[m1]], %[[alloc]]
  // CHECK: memref.store %[[f]], %[[alloc]][%[[idx1]]]
  %0 = tensor.insert %f into %t1[%idx1] : tensor<5xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
  // CHECK: return %[[r]]
  return %0 : tensor<5xf32>
}

// -----

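// Note: tensor.expand_shape bufferizes to memref.expand_shape on the source
// buffer; the dynamic output extent is recomputed with arith.divsi, as
// checked below.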
// CHECK-LABEL: func @tensor.expand_shape(
// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>
func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?x10xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  // CHECK: %[[DIM:.*]] = memref.dim %[[m1]], %[[C0]] : memref<?x10xf32>
  // CHECK: %[[C2:.*]] = arith.constant 2 : index
  // CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C2]] : index
  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[VAL_1]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
  %0 = tensor.expand_shape %t1 [[0, 1], [2]] output_shape [2, %sz0, 10]
      : tensor<?x10xf32> into tensor<2x?x10xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
  // CHECK: return %[[r]]
  return %0 : tensor<2x?x10xf32>
}

// -----

// CHECK-LABEL: func @tensor.expand_shape_of_slice(
// CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>
func.func @tensor.expand_shape_of_slice(
    %t1: tensor<?x20xf32>, %o1: index, %s1: index, %sz0: index) -> tensor<?x7x2x5xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] :
  // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, strided<[20, 1], offset: ?>>
  %0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
      tensor<?x20xf32> to tensor<?x10xf32>
  // CHECK: %[[C7:.*]] = arith.constant 7 : index
  // CHECK: %[[VAL_1:.*]] = arith.divsi %{{.*}}, %[[C7]] : index
  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
  %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] output_shape [%sz0, 7, 2, 5] :
      tensor<?x10xf32> into tensor<?x7x2x5xf32>
  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
  // CHECK: return %[[r]]
  return %1 : tensor<?x7x2x5xf32>
}

// -----

// CHECK-LABEL: func @tensor.expand_shape_of_scalar_slice(
// CHECK-SAME: %[[t1:.*]]: tensor<?xf32>
func.func @tensor.expand_shape_of_scalar_slice(
    %t1: tensor<?xf32>, %o1: index, %s1: index) -> tensor<1xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : tensor<?xf32> to memref<?xf32>
  // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}] [1] [1] : memref<?xf32> to memref<f32, strided<[], offset: ?>>
  %0 = tensor.extract_slice %t1[%o1][1][1] : tensor<?xf32> to tensor<f32>
  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] output_shape [1] : memref<f32, strided{{.*}}> into memref<1xf32, strided<[1], offset: ?>>
  %1 = tensor.expand_shape %0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
  // CHECK: return %[[r]]
  return %1 : tensor<1xf32>
}

// -----

// CHECK-LABEL: func @tensor.collapse_shape(
// CHECK-SAME: %[[t1:.*]]: tensor<2x?x?xf32>
func.func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?x?xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : tensor<2x?x?xf32> to memref<2x?x?xf32>
  // CHECK: %[[collapsed:.*]] = memref.collapse_shape %[[m1]] [
  // CHECK-SAME: [0, 1], [2]] : memref<2x?x?xf32> into memref<?x?xf32>
  %0 = tensor.collapse_shape %t1 [[0, 1], [2]]
      : tensor<2x?x?xf32> into tensor<?x?xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[collapsed]]
  // CHECK: return %[[r]]
  return %0 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @tensor.collapse_shape_to_scalar(
// CHECK-SAME: %[[t1:.*]]: tensor<1x1x1xf32>
func.func @tensor.collapse_shape_to_scalar(%t1: tensor<1x1x1xf32>) -> tensor<f32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : tensor<1x1x1xf32> to memref<1x1x1xf32>
  // CHECK: %[[collapsed:.*]] = memref.collapse_shape %[[m1]] [] : memref<1x1x1xf32> into memref<f32>
  %0 = tensor.collapse_shape %t1 []
      : tensor<1x1x1xf32> into tensor<f32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[collapsed]]
  // CHECK: return %[[r]]
  return %0 : tensor<f32>
}

// -----

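// The following tests collapse strided subviews. A memref.collapse_shape is
// emitted directly when the result is still expressible as a strided memref;
// otherwise the buffer is first copied into an identity-layout allocation
// (see slice2 and slice5 below).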
// CHECK-LABEL: func @tensor.collapse_shape_of_slice(
func.func @tensor.collapse_shape_of_slice(%arg0: tensor<2xi32>) -> tensor<i32> {
  // CHECK: memref.subview %{{.*}}[1] [1] [1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>>
  %0 = tensor.extract_slice %arg0[1] [1] [1] : tensor<2xi32> to tensor<1xi32>
  // CHECK: memref.collapse_shape %{{.*}} [] : memref<1xi32, strided<[1], offset: 1>> into memref<i32, strided<[], offset: 1>>
  %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
  return %1 : tensor<i32>
}

// -----

// CHECK-LABEL: func @tensor.collapse_shape_of_slice2(
func.func @tensor.collapse_shape_of_slice2(
    %arg0: tensor<?x?x?x?xi64>, %o1: index, %o2: index, %o3: index, %o4: index)
    -> tensor<87x63648xi64> {
  // CHECK: %[[subview:.*]] = memref.subview %{{.*}} : memref<?x?x?x?xi64> to memref<87x78x68x12xi64, strided{{.*}}>
  %0 = tensor.extract_slice %arg0[%o1, %o2, %o3, %o4] [87, 78, 68, 12] [1, 1, 1, 1] : tensor<?x?x?x?xi64> to tensor<87x78x68x12xi64>

  // This memref may not be collapsible, so the buffer must be copied to get rid
  // of the layout map.
  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<87x78x68x12xi64>
  // CHECK: memref.copy %[[subview]], %[[alloc]]
  // CHECK: memref.collapse_shape %[[alloc]] [
  // CHECK-SAME: [0], [1, 2, 3]] : memref<87x78x68x12xi64> into memref<87x63648xi64>
  %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] : tensor<87x78x68x12xi64> into tensor<87x63648xi64>
  return %1 : tensor<87x63648xi64>
}

// -----

// CHECK-LABEL: func @tensor.collapse_shape_of_slice3(
// CHECK-SAME: %[[t1:.*]]: tensor<1x2xf32>
func.func @tensor.collapse_shape_of_slice3(%t1: tensor<1x2xf32>) -> tensor<1xf32> {
  // CHECK: memref.subview {{.*}} : memref<1x2xf32> to memref<1x1xf32, strided<[2, 1]>>
  %0 = tensor.extract_slice %t1[0, 0][1, 1][1, 1] : tensor<1x2xf32> to tensor<1x1xf32>
  // CHECK: memref.collapse_shape %{{.*}} [
  // CHECK-SAME: [0, 1]] : memref<1x1xf32, strided<[2, 1]>> into memref<1xf32, strided<[2]>>
  %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x1xf32> into tensor<1xf32>
  return %1 : tensor<1xf32>
}

// -----

// CHECK-LABEL: func @tensor.collapse_shape_of_slice4(
// CHECK-SAME: %[[t1:.*]]: tensor<?x2x4xf32>,
// CHECK-SAME: %[[OFFSET:.*]]: index) -> tensor<8xf32> {
func.func @tensor.collapse_shape_of_slice4(%arg0: tensor<?x2x4xf32>, %offset: index, %size: index) -> tensor<8xf32> {
  // CHECK: memref.subview %{{.*}} : memref<?x2x4xf32> to memref<4x2x1xf32, strided<[8, 4, 1], offset: ?>>
  %0 = tensor.extract_slice %arg0[0, 0, %offset] [4, 2, 1] [1, 1, 1] : tensor<?x2x4xf32> to tensor<4x2x1xf32>
  // CHECK: memref.collapse_shape %{{.*}} [
  // CHECK-SAME: [0, 1, 2]] : memref<4x2x1xf32, strided<[8, 4, 1], offset: ?>> into memref<8xf32, strided<[4], offset: ?>>
  %ret = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<4x2x1xf32> into tensor<8xf32>
  return %ret: tensor<8xf32>
}

// -----

// CHECK-LABEL: func @tensor.collapse_shape_of_slice5(
func.func @tensor.collapse_shape_of_slice5(%arg0: tensor<2x2x2xi64>) -> tensor<4xi64> {
  // CHECK: %[[subview:.*]] = memref.subview %{{.*}} : memref<2x2x2xi64> to memref<2x1x2xi64, {{.*}}>
  %0 = tensor.extract_slice %arg0[0, 0, 0] [2, 1, 2] [1, 1, 1] : tensor<2x2x2xi64> to tensor<2x1x2xi64>

  // This memref is not collapsible, so the buffer must be copied to get rid of
  // the layout map.
  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<2x1x2xi64>
  // CHECK: memref.copy %[[subview]], %[[alloc]]
  // CHECK: memref.collapse_shape %[[alloc]] [
  // CHECK-SAME: [0, 1, 2]] : memref<2x1x2xi64> into memref<4xi64>
  %1 = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<2x1x2xi64> into tensor<4xi64>
  return %1 : tensor<4xi64>
}

// -----

// CHECK-LABEL: func @tensor.reshape(
// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>
func.func @tensor.reshape(%t1: tensor<?x10xf32>) -> tensor<2x2x5xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : tensor<?x10xf32> to memref<?x10xf32>

  // CHECK: %[[two:.*]] = arith.constant 2 : i64
  %two = arith.constant 2 : i64
  // CHECK: %[[five:.*]] = arith.constant 5 : i64
  %five = arith.constant 5 : i64

  // CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3xi64>
  // CHECK: %[[zero_idx:.*]] = arith.constant 0 : index
  // CHECK: %[[one_idx:.*]] = arith.constant 1 : index
  // CHECK: %[[two_idx:.*]] = arith.constant 2 : index
  // CHECK: memref.store %[[two]], %[[alloc]][%[[zero_idx]]] : memref<3xi64>
  // CHECK: memref.store %[[two]], %[[alloc]][%[[one_idx]]] : memref<3xi64>
  // CHECK: memref.store %[[five]], %[[alloc]][%[[two_idx]]] : memref<3xi64>
  %shape = tensor.from_elements %two, %two, %five : tensor<3xi64>

  // CHECK: %[[reshaped:.*]] = memref.reshape %[[m1]](%[[alloc]]) : (memref<?x10xf32>, memref<3xi64>) -> memref<2x2x5xf32>
  %reshaped = tensor.reshape %t1(%shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<2x2x5xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[reshaped]]
  // CHECK: return %[[r]]
  return %reshaped : tensor<2x2x5xf32>
}

// -----

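// Note: tensor.pad bufferizes to an allocation of the padded size, a
// linalg.map that evaluates the padding body, and a copy of the source buffer
// into a subview of the result, as checked below.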
// CHECK: #[[$sum_map_1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
// CHECK: #[[$sum_map_2:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 10)>
// CHECK-LABEL: func @tensor.pad(
// CHECK-SAME: %[[t1:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index
func.func @tensor.pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
    %h2: index) -> tensor<?x?xindex> {
  // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : tensor<?x10xindex> to memref<?x10xindex>
  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
  // CHECK-DAG: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
  // CHECK-DAG: %[[size0:.*]] = affine.apply #[[$sum_map_1]]()[%[[h1]], %[[dim0]]]
  // CHECK-DAG: %[[size1:.*]] = affine.apply #[[$sum_map_2]]()[%[[l2]], %[[h2]]]
  // CHECK: %[[alloc:.*]] = memref.alloc(%[[size0]], %[[size1]]) {{.*}} : memref<?x?xindex>
  // CHECK: %[[alloc_t:.*]] = bufferization.to_tensor %[[alloc]]
  // CHECK: %[[mapped:.*]] = linalg.map
  // CHECK: outs(%[[alloc_t]] : tensor<?x?xindex>)
  // CHECK: %[[index0:.*]] = linalg.index 0
  // CHECK: %[[index1:.*]] = linalg.index 1
  // CHECK: %[[mul:.*]] = arith.muli %[[index0]], %[[index1]]
  // CHECK: linalg.yield %[[mul]]
  // CHECK: }
  // CHECK: %[[mapped_m:.*]] = bufferization.to_memref %[[mapped]]
  // CHECK: %[[subview:.*]] = memref.subview %[[mapped_m]][5, %[[l2]]] [%[[dim0]], 10] [1, 1]
  // CHECK: memref.copy %[[m1]], %[[subview]]
  %0 = tensor.pad %t1 low[5, %l2] high[%h1, %h2] {
  ^bb0(%arg0: index, %arg1: index):
    %m = arith.muli %arg0, %arg1 : index
    tensor.yield %m : index
  } : tensor<?x10xindex> to tensor<?x?xindex>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[mapped_m]]
  // CHECK: return %[[r]] : tensor<?x?xindex>
  return %0 : tensor<?x?xindex>
}

// -----

// CHECK-LABEL: func @tensor.splat(
// CHECK-SAME: %[[F:.*]]: f32)
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4xf32>
// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: %[[MAPPED:.*]] = linalg.map
// CHECK: outs(%[[ALLOC_T]] : tensor<10x2x4xf32>)
// CHECK: linalg.yield %[[F]]
// CHECK: }
// CHECK: return %[[MAPPED]] : tensor<10x2x4xf32>
// CHECK: }
func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
  %t = tensor.splat %f : tensor<10x2x4xf32>
  return %t : tensor<10x2x4xf32>
}

// -----

// CHECK-LABEL: func @tensor.splat_dynamic(
// CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[N:[a-zA-Z0-9_]+]]: index
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc(%[[M]], %[[N]]) {{.*}} : memref<?x3x?xf32>
// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor<?x3x?xf32>)
// CHECK: () {
// CHECK: linalg.yield %[[F]] : f32
// CHECK: }
// CHECK: return %[[MAPPED]] : tensor<?x3x?xf32>
// CHECK: }
func.func @tensor.splat_dynamic(%f: f32, %m: index, %n: index) -> tensor<?x3x?xf32> {
  %0 = tensor.splat %f[%m, %n] : tensor<?x3x?xf32>
  return %0 : tensor<?x3x?xf32>
}

// -----

// CHECK-LABEL: func.func @parallel_insert_slice_copy_before_write
func.func @parallel_insert_slice_copy_before_write(%in: tensor<4xf32>, %out: tensor<4xf32>) {
  %c1 = arith.constant 1 : index
  %num_threads = arith.constant 4 : index

  // CHECK: scf.forall {{.*}} {
  %result = scf.forall (%thread_idx) in (%num_threads) shared_outs (%o = %out) -> tensor<4xf32> {
    %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<4xf32> to tensor<1xf32>
    scf.forall.in_parallel {
      // CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<4xf32> to memref<1xf32, strided<[1], offset: ?>>
      // CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<4xf32> to memref<1xf32, strided<[1], offset: ?>>
      tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
          tensor<1xf32> into tensor<4xf32>
    }
  }
  // CHECK: }
  return
}