// RUN: mlir-opt --split-input-file --test-mesh-all-slice-op-lowering --test-mesh-simplifications --cse %s | FileCheck %s

// Each test lowers a single `mesh.all_slice` op. The expected IR computes the
// result size along `slice_axis` as (tensor axis size) / (process-group size),
// asserting divisibility when the sizes are dynamic, and extracts the slice at
// offset (process index within the group) * (result axis size).

mesh.mesh @mesh_1d(shape = ?)

// Dynamic tensor on a dynamic 1-D mesh: the axis size, mesh size and offset
// are all computed at runtime.
// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh
func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh(
  // CHECK: %[[ARG:.*]]: tensor<?xf16>
  %arg0: tensor<?xf16>
// CHECK-SAME: -> tensor<?xf16> {
) -> tensor<?xf16> {
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
  // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
  // The zero index is matched through the captured constant, not its printed SSA name.
  // CHECK: %[[TENSOR_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?xf16>
  // CHECK: %[[AXIS_SIZE_CHECK_REMAINDER:.*]] = arith.remui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
  // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMAINDER]], %[[C0]] : index
  // CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
  // CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
  // CHECK: %[[SLICE_OFFSET:.*]] = arith.muli %[[PROC_IDX]], %[[RESULT_AXIS_SIZE]] : index
  // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][%[[SLICE_OFFSET]]] [%[[RESULT_AXIS_SIZE]]] [1] : tensor<?xf16> to tensor<?xf16>
  %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor<?xf16> -> tensor<?xf16>
  // CHECK: return %[[RESULT]] : tensor<?xf16>
  return %0 : tensor<?xf16>
}

// -----

mesh.mesh @mesh_1d(shape = 2)

// Static tensor on a static 1-D mesh of size 2: the slice size is the
// constant 1, and the dynamically-shaped slice is cast back to the static
// result type.
// CHECK-LABEL: func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh
func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
  // CHECK: %[[ARG:.*]]: tensor<2xf16>
  %arg0: tensor<2xf16>
// CHECK-SAME: -> tensor<1xf16> {
) -> tensor<1xf16> {
  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
  // CHECK: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
  // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[PROC_IDX]]] [%[[C1]]] [1] : tensor<2xf16> to tensor<?xf16>
  // CHECK: %[[RESULT:.*]] = tensor.cast %[[SLICE]] : tensor<?xf16> to tensor<1xf16>
  %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor<2xf16> -> tensor<1xf16>
  // CHECK: return %[[RESULT]] : tensor<1xf16>
  return %0 : tensor<1xf16>
}

// -----

// The linearization map is captured in a global ($-prefixed) variable so the
// match does not depend on the alias name chosen by the printer.
// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>

mesh.mesh @mesh_4d(shape = ?x?x?x?)

// Dynamic 2-D tensor sliced along axis 1 over the process group formed by
// mesh axes [3, 1]: the group size is the product of those axis sizes, and the
// process's linear index within the group is built with affine.apply from its
// multi-index and the group shape. Axis 0 is kept whole.
// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh
func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh(
  // CHECK: %[[ARG:.*]]: tensor<?x?xf16>
  %arg0 : tensor<?x?xf16>
// CHECK-SAME: -> tensor<?x?xf16> {
) -> tensor<?x?xf16> {
  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = mesh.process_multi_index on @mesh_4d axes = [3, 1] : index, index
  // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = mesh.mesh_shape @mesh_4d axes = [3, 1] : index, index
  // CHECK: %[[PROC_GROUP_SIZE:.*]] = arith.muli %[[PROC_GROUP_SHAPE]]#0, %[[PROC_GROUP_SHAPE]]#1 : index
  // CHECK: %[[SLICE_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf16>
  // CHECK: %[[AXIS_SIZE_CHECK_REMAINDER:.*]] = arith.remui %[[SLICE_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
  // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMAINDER]], %[[C0]] : index
  // CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
  // CHECK: %[[RESULT_SLICE_AXIS_SIZE:.*]] = arith.divui %[[SLICE_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
  // CHECK: %[[PROC_IN_GROUP_LINEAR_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IN_GROUP_PROC_MULTI_IDX]]#0, %[[PROC_GROUP_SHAPE]]#1, %[[IN_GROUP_PROC_MULTI_IDX]]#1]
  // CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?x?xf16>
  // CHECK: %[[SLICE_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SLICE_AXIS_SIZE]] : index
  // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SLICE_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SLICE_AXIS_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
  %0 = mesh.all_slice %arg0 on @mesh_4d mesh_axes = [3, 1] slice_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
  // CHECK: return %[[RESULT]] : tensor<?x?xf16>
  return %0 : tensor<?x?xf16>
}