// xref: /llvm-project/mlir/test/Dialect/Mesh/folding.mlir (revision 9a8437f50470e2658ca0b26bbc9f3da654c20dba)
// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s

// @mesh0 has static axes 0 (size 4) and 2 (size 2) and a dynamic axis 1 ('?').
// @mesh1 is fully static (sizes 2 and 3).
mesh.mesh @mesh0(shape = 4x?x2)
mesh.mesh @mesh1(shape = 2x3)

// Queries axes [2, 1] of @mesh0. Axis 2 has static size 2 and is expected to
// fold to a constant. Axis 1 is dynamic ('?') and is expected to remain as a
// single-axis mesh.mesh_shape query. Result order must be preserved in the return.
// The trailing '(' on the label keeps it from also matching the longer
// function name @mesh_shape_op_folding_all_axes_static_mesh.
// CHECK-LABEL: func.func @mesh_shape_op_folding(
func.func @mesh_shape_op_folding() -> (index, index) {
  // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
  // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.mesh_shape @mesh0 axes = [1] : index
  %0:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index, index
  // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
  return %0#0, %0#1 : index, index
}

// Queries every axis of the fully static @mesh1 (no axes list given). Both
// results are expected to fold to constants, 2 and 3, in mesh-axis order.
// CHECK-LABEL: func.func @mesh_shape_op_folding_all_axes_static_mesh
func.func @mesh_shape_op_folding_all_axes_static_mesh() -> (index, index) {
  // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
  // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
  %0:2 = mesh.mesh_shape @mesh1 : index, index
  // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
  return %0#0, %0#1 : index, index
}
