// RUN: mlir-opt \
// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
// RUN:   %s | FileCheck %s

// Tests spmdization of tensor.empty under a 1-D mesh of 4 devices:
// a sharded tensor.empty must be rewritten to a local-shape tensor.empty,
// using mesh.shard_shape when per-device sizes differ, or a plain static
// shape when all shards are equal.
mesh.mesh @mesh_1d_4(shape = 4)

// Uneven static sharding (offsets [0,1,4,7,8] give shard sizes 1,3,3,1), so
// the local dim 0 size is dynamic and must be queried via mesh.shard_shape.
// CHECK-LABEL: func @tensor_empty_static_sharded_dims_offsets
func.func @tensor_empty_static_sharded_dims_offsets() -> () {
  %b = tensor.empty() : tensor<8x16xf32>
  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
  %sharded = mesh.shard %b to %sharding : tensor<8x16xf32>
  // CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
  // CHECK: %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
  // CHECK: %[[V0:.*]]:2 = mesh.shard_shape 8x16 %[[sharding]] %[[proc_linear_idx]] : index, index
  // CHECK: tensor.empty(%[[V0]]#0) : tensor<?x16xf32>

  return
}

// Same uneven sharding, but dim 1 is already dynamic: the rewritten
// tensor.empty must carry both the shard-shape result and the original
// dynamic-size operand.
// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_offsets
// CHECK-SAME: %[[A0:.*]]: index
func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
  %b = tensor.empty(%arg0) : tensor<8x?xf32>
  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
  %sharded = mesh.shard %b to %sharding : tensor<8x?xf32>
  // CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
  // CHECK: %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
  // CHECK: %[[V0:.*]]:2 = mesh.shard_shape 8x? %[[sharding]] %[[proc_linear_idx]] : index, index
  // CHECK: tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>

  return
}

// Even sharding (offsets [0,4,8,12,16] give equal shard sizes of 4): the
// local shape is fully static, so no mesh.shard_shape query is needed.
// CHECK-LABEL: func @tensor_empty_same_static_dims_sizes
func.func @tensor_empty_same_static_dims_sizes() -> () {
  %b = tensor.empty() : tensor<16x16xf32>
  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 4, 8, 12, 16] : !mesh.sharding
  %sharded = mesh.shard %b to %sharding : tensor<16x16xf32>
  // CHECK-NEXT: tensor.empty() : tensor<4x16xf32>

  return
}