xref: /dpdk/drivers/ml/cnxk/mvtvm_ml_dev.c (revision 8df71650e9fdc6346f09b7a57e86cded7b553152)
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright (c) 2023 Marvell.
3  */
4 
#include <errno.h>
#include <limits.h>
#include <stdint.h>
#include <stdlib.h>

#include <rte_kvargs.h>
#include <rte_mldev.h>
#include <rte_mldev_pmd.h>

#include <bus_vdev_driver.h>

#include <roc_api.h>

#include "cnxk_ml_dev.h"
14 
/* Devargs keys accepted by this vdev. */
#define MVTVM_ML_DEV_MAX_QPS          "max_qps"
#define MVTVM_ML_DEV_CACHE_MODEL_DATA "cache_model_data"

/* Values applied by mvtvm_mldev_parse_devargs() when a key is not supplied.
 * NOTE(review): the CN10K_ prefix looks copied from the cn10k driver; it is
 * kept because mvtvm_mldev_parse_devargs() references this exact name.
 */
#define MVTVM_ML_DEV_MAX_QPS_DEFAULT          32
#define CN10K_ML_DEV_CACHE_MODEL_DATA_DEFAULT 1

/* NULL-terminated key list handed to rte_kvargs_parse(). */
static const char *const valid_args[] = {
	MVTVM_ML_DEV_MAX_QPS,
	MVTVM_ML_DEV_CACHE_MODEL_DATA,
	NULL,
};
22 
/*
 * rte_kvargs handler: parse a non-negative decimal integer into an int.
 *
 * @param key        Devarg key (unused).
 * @param value      Devarg value string; a NULL value is rejected.
 * @param extra_args Pointer to the int that receives the parsed value.
 * @return 0 on success; -EINVAL if value is missing, empty, not entirely a
 *         decimal number, negative, or larger than INT_MAX. On failure the
 *         destination int is left unchanged.
 */
static int
parse_integer_arg(const char *key, const char *value, void *extra_args)
{
	int *i = (int *)extra_args;
	char *end;
	long v;

	(void)key;

	if (value == NULL) {
		plt_err("Argument has to be a non-negative integer.");
		return -EINVAL;
	}

	/* Unlike atoi(), strtol() lets us detect empty input, trailing garbage
	 * and overflow instead of silently producing a number.
	 */
	errno = 0;
	v = strtol(value, &end, 10);
	if (end == value || *end != '\0' || errno != 0 || v < 0 || v > INT_MAX) {
		plt_err("Argument has to be a non-negative integer.");
		return -EINVAL;
	}

	*i = (int)v;

	return 0;
}
36 
/*
 * rte_kvargs handler: parse a non-negative decimal integer into a uint32_t.
 *
 * @param key        Devarg key (unused).
 * @param value      Devarg value string; a NULL value is rejected.
 * @param extra_args Pointer to the uint32_t that receives the parsed value.
 * @return 0 on success; -EINVAL if value is missing, empty, not entirely a
 *         decimal number, negative, or larger than UINT32_MAX. On failure the
 *         destination is left unchanged.
 */
static int
parse_uint_arg(const char *key, const char *value, void *extra_args)
{
	long long v;
	char *end;

	(void)key;

	if (value == NULL)
		return -EINVAL;

	/* Parse into a 64-bit type so out-of-range input is rejected rather than
	 * truncated; end == value catches the empty string.
	 */
	errno = 0;
	v = strtoll(value, &end, 10);
	if (end == value || *end != '\0' || errno != 0 || v < 0 || v > UINT32_MAX)
		return -EINVAL;

	*((uint32_t *)extra_args) = (uint32_t)v;

	return 0;
}
52 
53 static int
54 mvtvm_mldev_parse_devargs(const char *args, struct mvtvm_ml_dev *mvtvm_mldev)
55 {
56 	bool cache_model_data_set = false;
57 	struct rte_kvargs *kvlist = NULL;
58 	bool max_qps_set = false;
59 	int ret = 0;
60 
61 	if (args == NULL)
62 		goto check_args;
63 
64 	kvlist = rte_kvargs_parse(args, valid_args);
65 	if (kvlist == NULL) {
66 		plt_err("Error parsing %s devargs", "MLDEV_NAME_MVTVM_PMD");
67 		return -EINVAL;
68 	}
69 
70 	if (rte_kvargs_count(kvlist, MVTVM_ML_DEV_MAX_QPS) == 1) {
71 		ret = rte_kvargs_process(kvlist, MVTVM_ML_DEV_MAX_QPS, &parse_uint_arg,
72 					 &mvtvm_mldev->max_nb_qpairs);
73 		if (ret < 0) {
74 			plt_err("Error processing arguments, key = %s", MVTVM_ML_DEV_MAX_QPS);
75 			ret = -EINVAL;
76 			goto exit;
77 		}
78 		max_qps_set = true;
79 	}
80 
81 	if (rte_kvargs_count(kvlist, MVTVM_ML_DEV_CACHE_MODEL_DATA) == 1) {
82 		ret = rte_kvargs_process(kvlist, MVTVM_ML_DEV_CACHE_MODEL_DATA, &parse_integer_arg,
83 					 &mvtvm_mldev->cache_model_data);
84 		if (ret < 0) {
85 			plt_err("Error processing arguments, key = %s",
86 				MVTVM_ML_DEV_CACHE_MODEL_DATA);
87 			ret = -EINVAL;
88 			goto exit;
89 		}
90 		cache_model_data_set = true;
91 	}
92 
93 check_args:
94 	if (!max_qps_set)
95 		mvtvm_mldev->max_nb_qpairs = MVTVM_ML_DEV_MAX_QPS_DEFAULT;
96 	plt_ml_dbg("ML: %s = %u", MVTVM_ML_DEV_MAX_QPS, mvtvm_mldev->max_nb_qpairs);
97 
98 	if (!cache_model_data_set) {
99 		mvtvm_mldev->cache_model_data = CN10K_ML_DEV_CACHE_MODEL_DATA_DEFAULT;
100 	} else {
101 		if ((mvtvm_mldev->cache_model_data < 0) || (mvtvm_mldev->cache_model_data > 1)) {
102 			plt_err("Invalid argument, %s = %d", MVTVM_ML_DEV_CACHE_MODEL_DATA,
103 				mvtvm_mldev->cache_model_data);
104 			ret = -EINVAL;
105 			goto exit;
106 		}
107 	}
108 	plt_ml_dbg("ML: %s = %d", MVTVM_ML_DEV_CACHE_MODEL_DATA, mvtvm_mldev->cache_model_data);
109 
110 exit:
111 	rte_kvargs_free(kvlist);
112 
113 	return ret;
114 }
115 
116 static int
117 mvtvm_ml_vdev_probe(struct rte_vdev_device *vdev)
118 {
119 	struct rte_ml_dev_pmd_init_params init_params;
120 	struct mvtvm_ml_dev *mvtvm_mldev;
121 	struct cnxk_ml_dev *cnxk_mldev;
122 	struct rte_ml_dev *dev;
123 	const char *input_args;
124 	const char *name;
125 	int ret = 0;
126 
127 	if (cnxk_ml_dev_initialized == 1) {
128 		plt_err("ML CNXK device already initialized!");
129 		plt_err("Not creating ml_mvtvm vdev!");
130 		return 0;
131 	}
132 
133 	init_params = (struct rte_ml_dev_pmd_init_params){
134 		.socket_id = rte_socket_id(), .private_data_size = sizeof(struct cnxk_ml_dev)};
135 
136 	name = rte_vdev_device_name(vdev);
137 	if (name == NULL)
138 		return -EINVAL;
139 	input_args = rte_vdev_device_args(vdev);
140 
141 	dev = rte_ml_dev_pmd_create(name, &vdev->device, &init_params);
142 	if (dev == NULL) {
143 		ret = -EFAULT;
144 		goto error_exit;
145 	}
146 
147 	cnxk_mldev = dev->data->dev_private;
148 	cnxk_mldev->mldev = dev;
149 	mvtvm_mldev = &cnxk_mldev->mvtvm_mldev;
150 	mvtvm_mldev->vdev = vdev;
151 
152 	ret = mvtvm_mldev_parse_devargs(input_args, mvtvm_mldev);
153 	if (ret < 0)
154 		goto error_exit;
155 
156 	dev->dev_ops = &cnxk_ml_ops;
157 	dev->enqueue_burst = NULL;
158 	dev->dequeue_burst = NULL;
159 	dev->op_error_get = NULL;
160 
161 	cnxk_ml_dev_initialized = 1;
162 	cnxk_mldev->type = CNXK_ML_DEV_TYPE_VDEV;
163 
164 	return 0;
165 
166 error_exit:
167 	plt_err("Could not create device: ml_mvtvm");
168 
169 	return ret;
170 }
171 
/*
 * vdev remove callback: find the ML device registered under this vdev's name
 * and destroy it.
 *
 * @param vdev Virtual device being removed.
 * @return -EINVAL if the vdev has no name, -ENODEV if no ML device with that
 *         name exists, otherwise the result of rte_ml_dev_pmd_destroy().
 */
static int
mvtvm_ml_vdev_remove(struct rte_vdev_device *vdev)
{
	const char *dev_name = rte_vdev_device_name(vdev);
	struct rte_ml_dev *mldev;

	if (dev_name == NULL)
		return -EINVAL;

	mldev = rte_ml_dev_pmd_get_named_dev(dev_name);

	return (mldev == NULL) ? -ENODEV : rte_ml_dev_pmd_destroy(mldev);
}
188 
189 static struct rte_vdev_driver mvtvm_mldev_pmd = {.probe = mvtvm_ml_vdev_probe,
190 						 .remove = mvtvm_ml_vdev_remove};
191 
192 RTE_PMD_REGISTER_VDEV(MLDEV_NAME_MVTVM_PMD, mvtvm_mldev_pmd);
193 
194 RTE_PMD_REGISTER_PARAM_STRING(MLDEV_NAME_MVTVM_PMD,
195 			      MVTVM_ML_DEV_MAX_QPS "=<int>" MVTVM_ML_DEV_CACHE_MODEL_DATA "=<0|1>");
196