Lines Matching defs:odsState

400 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
402 build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
405 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
407 build(odsBuilder, odsState,
413 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
416 build(odsBuilder, odsState,
430 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
438 b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
447 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
451 b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
457 ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
466 b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
472 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
475 build(b, odsState, ShardingType::get(b.getContext()), from.getMeshAttr(),
758 ::mlir::OperationState &odsState,
762 build(odsBuilder, odsState, resType, shape, sharding, device);
798 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
800 build(odsBuilder, odsState,
805 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
807 build(odsBuilder, odsState,
831 OperationState &odsState, MeshOp mesh) {
832 build(odsBuilder, odsState, mesh.getSymName());
1107 void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1110 build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
1138 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1142 build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
1146 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1149 build(odsBuilder, odsState, resultType, mesh, meshAxes, input,