xref: /llvm-project/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir (revision dc3258c617420e83caff63c93d548e0923b10791)
// RUN: mlir-opt --split-input-file --test-mesh-all-slice-op-lowering --test-mesh-simplifications --cse %s | FileCheck %s

// Lowering tests for mesh.all_slice. Every split section of this file is
// processed as an independent module; the op is expected to be rewritten into
// index arithmetic on the process/mesh coordinates followed by a
// tensor.extract_slice of the local chunk.

// 1-D mesh whose size is not known statically.
mesh.mesh @mesh_1d(shape = ?)

// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh
func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh(
  // CHECK: %[[ARG:.*]]: tensor<?xf16>
  %arg0: tensor<?xf16>
// CHECK-SAME: -> tensor<?xf16> {
) -> tensor<?xf16> {
  // Both the tensor extent and the mesh size are dynamic, so the expected
  // lowering asserts at runtime that the sliced axis divides evenly by the
  // mesh size before computing the local chunk.
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
  // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
  // CHECK: %[[TENSOR_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?xf16>
  // CHECK: %[[AXIS_SIZE_CHECK_REMAINDER:.*]] = arith.remui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
  // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMAINDER]], %[[C0]] : index
  // CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
  // The local chunk has size axis_size / mesh_size and starts at
  // proc_idx * chunk_size along the sliced axis.
  // CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
  // CHECK: %[[SLICE_OFFSET:.*]] = arith.muli %[[PROC_IDX]], %[[RESULT_AXIS_SIZE]] : index
  // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][%[[SLICE_OFFSET]]] [%[[RESULT_AXIS_SIZE]]] [1] : tensor<?xf16> to tensor<?xf16>
  %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor<?xf16> -> tensor<?xf16>
  // CHECK: return %[[RESULT]] : tensor<?xf16>
  return %0 : tensor<?xf16>
}

// -----

// 1-D mesh with a static size of 2.
mesh.mesh @mesh_1d(shape = 2)

// CHECK-LABEL: func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh
func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
  // CHECK: %[[ARG:.*]]: tensor<2xf16>
  %arg0: tensor<2xf16>
// CHECK-SAME: -> tensor<1xf16> {
) -> tensor<1xf16> {
  // With 2 elements split over 2 processes, the chunk size is expected to
  // fold to the constant 1, so the slice offset is the process index itself.
  // The slice is produced with a dynamic type and then cast to the static
  // result type. The constant and the process index are matched in either
  // order.
  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
  // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[PROC_IDX]]] [%[[C1]]] [1] : tensor<2xf16> to tensor<?xf16>
  // CHECK: %[[RESULT:.*]] = tensor.cast %[[SLICE]] : tensor<?xf16> to tensor<1xf16>
  %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor<2xf16> -> tensor<1xf16>
  // CHECK: return %[[RESULT]] : tensor<1xf16>
  return %0 : tensor<1xf16>
}

// -----

// Capture the printer's alias for the linearisation map instead of assuming it
// is named "#map"; the '$' prefix keeps the capture alive across the label below.
// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>

// 4-D mesh whose extents are all dynamic.
mesh.mesh @mesh_4d(shape = ?x?x?x?)

// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh
func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh(
  // CHECK: %[[ARG:.*]]: tensor<?x?xf16>
  %arg0 : tensor<?x?xf16>
// CHECK-SAME: -> tensor<?x?xf16> {
) -> tensor<?x?xf16> {
  // Slicing over mesh axes [3, 1] treats those two axes as one process group:
  // the group size is the product of their extents, and the process's linear
  // index inside the group is idx#0 * shape#1 + idx#1 (the captured map).
  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = mesh.process_multi_index on @mesh_4d axes = [3, 1] : index, index
  // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = mesh.mesh_shape @mesh_4d axes = [3, 1] : index, index
  // CHECK: %[[PROC_GROUP_SIZE:.*]] = arith.muli %[[PROC_GROUP_SHAPE]]#0, %[[PROC_GROUP_SHAPE]]#1 : index
  // CHECK: %[[SLICE_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf16>
  // CHECK: %[[AXIS_SIZE_CHECK_REMAINDER:.*]] = arith.remui %[[SLICE_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
  // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMAINDER]], %[[C0]] : index
  // CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
  // CHECK: %[[RESULT_SLICE_AXIS_SIZE:.*]] = arith.divui %[[SLICE_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
  // CHECK: %[[PROC_IN_GROUP_LINEAR_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IN_GROUP_PROC_MULTI_IDX]]#0, %[[PROC_GROUP_SHAPE]]#1, %[[IN_GROUP_PROC_MULTI_IDX]]#1]
  // Axis 0 is not sliced: it is taken whole (offset 0, full extent).
  // CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?x?xf16>
  // CHECK: %[[SLICE_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SLICE_AXIS_SIZE]] : index
  // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SLICE_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SLICE_AXIS_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
  %0 = mesh.all_slice %arg0 on @mesh_4d mesh_axes = [3, 1] slice_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
  // CHECK: return %[[RESULT]] : tensor<?x?xf16>
  return %0 : tensor<?x?xf16>
}