/* SPDX-License-Identifier: BSD-3-Clause
 * Copyright (c) 2022 Marvell.
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <rte_errno.h>
#include <rte_log.h>
#include <rte_mldev.h>
#include <rte_mldev_pmd.h>
#include <rte_string_fns.h>

static struct rte_ml_dev_global ml_dev_globals = {
	.devs = NULL, .data = NULL, .nb_devs = 0, .max_devs = RTE_MLDEV_DEFAULT_MAX};

/*
 * Private data structure of an operation pool.
 *
 * A structure that contains ml op_pool specific data that is
 * appended after the mempool structure (in private data).
 */
struct rte_ml_op_pool_private {
	uint16_t user_size;
	/**< Size of private user data with each operation. */
};

struct rte_ml_dev *
rte_ml_dev_pmd_get_dev(int16_t dev_id)
{
	return &ml_dev_globals.devs[dev_id];
}

struct rte_ml_dev *
rte_ml_dev_pmd_get_named_dev(const char *name)
{
	struct rte_ml_dev *dev;
	int16_t dev_id;

	if (name == NULL)
		return NULL;

	for (dev_id = 0; dev_id < ml_dev_globals.max_devs; dev_id++) {
		dev = rte_ml_dev_pmd_get_dev(dev_id);
		if ((dev->attached == ML_DEV_ATTACHED) && (strcmp(dev->data->name, name) == 0))
			return dev;
	}

	return NULL;
}

struct rte_ml_dev *
rte_ml_dev_pmd_allocate(const char *name, uint8_t socket_id)
{
	char mz_name[RTE_MEMZONE_NAMESIZE];
	const struct rte_memzone *mz;
	struct rte_ml_dev *dev;
	int16_t dev_id;

	/* implicit initialization of library before adding first device */
	if (ml_dev_globals.devs == NULL) {
		if (rte_ml_dev_init(RTE_MLDEV_DEFAULT_MAX) != 0)
			return NULL;
	}

	if (rte_ml_dev_pmd_get_named_dev(name) != NULL) {
		RTE_MLDEV_LOG(ERR, "ML device with name %s already allocated!", name);
		return NULL;
	}

	/* Get a free device ID */
	for (dev_id = 0; dev_id < ml_dev_globals.max_devs; dev_id++) {
		dev = rte_ml_dev_pmd_get_dev(dev_id);
		if (dev->attached == ML_DEV_DETACHED)
			break;
	}

	if (dev_id == ml_dev_globals.max_devs) {
		RTE_MLDEV_LOG(ERR, "Reached maximum number of ML devices");
		return NULL;
	}

	if (dev->data == NULL) {
		/* Build memzone name and reserve/lookup the device data memzone */
		sprintf(mz_name, "rte_ml_dev_data_%d", dev_id);
		if (rte_eal_process_type() == RTE_PROC_PRIMARY) {
			mz = rte_memzone_reserve(mz_name, sizeof(struct rte_ml_dev_data), socket_id,
						 0);
			RTE_MLDEV_LOG(DEBUG, "PRIMARY: reserved memzone for %s (%p)", mz_name, mz);
		} else {
			mz = rte_memzone_lookup(mz_name);
			RTE_MLDEV_LOG(DEBUG, "SECONDARY: looked up memzone for %s (%p)", mz_name,
				      mz);
		}

		if (mz == NULL)
			return NULL;

		ml_dev_globals.data[dev_id] = mz->addr;
		if (rte_eal_process_type() == RTE_PROC_PRIMARY)
			memset(ml_dev_globals.data[dev_id], 0, sizeof(struct rte_ml_dev_data));

		dev->data = ml_dev_globals.data[dev_id];
		if (rte_eal_process_type() == RTE_PROC_PRIMARY) {
			strlcpy(dev->data->name, name, RTE_ML_STR_MAX);
			dev->data->dev_id = dev_id;
			dev->data->socket_id = socket_id;
			dev->data->dev_started = 0;
			RTE_MLDEV_LOG(DEBUG, "PRIMARY: init mldev data");
		}

		RTE_MLDEV_LOG(DEBUG, "Data for %s: dev_id %d, socket %u", dev->data->name,
			      dev->data->dev_id, dev->data->socket_id);

		dev->attached = ML_DEV_ATTACHED;
		ml_dev_globals.nb_devs++;
	}

	dev->enqueue_burst = NULL;
	dev->dequeue_burst = NULL;

	return dev;
}

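/*
 * Illustrative sketch (not part of this file): roughly how a driver's probe
 * path might use the PMD allocation helper above. The driver name, ops table
 * and fast-path handlers are hypothetical names for the example.
 *
 *	static int
 *	mydrv_ml_probe(const char *name, uint8_t socket_id)
 *	{
 *		struct rte_ml_dev *dev;
 *
 *		dev = rte_ml_dev_pmd_allocate(name, socket_id);
 *		if (dev == NULL)
 *			return -ENODEV;
 *
 *		dev->dev_ops = &mydrv_ml_ops;	// hypothetical ops table
 *		// fast-path handlers are reset by allocate(); set them here
 *		dev->enqueue_burst = mydrv_ml_enqueue_burst;
 *		dev->dequeue_burst = mydrv_ml_dequeue_burst;
 *
 *		return 0;
 *	}
 */
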
int
rte_ml_dev_pmd_release(struct rte_ml_dev *dev)
{
	char mz_name[RTE_MEMZONE_NAMESIZE];
	const struct rte_memzone *mz;
	int16_t dev_id;
	int ret = 0;

	if (dev == NULL)
		return -EINVAL;

	dev_id = dev->data->dev_id;

	/* Memzone lookup */
	sprintf(mz_name, "rte_ml_dev_data_%d", dev_id);
	mz = rte_memzone_lookup(mz_name);
	if (mz == NULL)
		return -ENOMEM;

	RTE_ASSERT(ml_dev_globals.data[dev_id] == mz->addr);
	ml_dev_globals.data[dev_id] = NULL;

	if (rte_eal_process_type() == RTE_PROC_PRIMARY) {
		RTE_MLDEV_LOG(DEBUG, "PRIMARY: free memzone of %s (%p)", mz_name, mz);
		ret = rte_memzone_free(mz);
	} else {
		RTE_MLDEV_LOG(DEBUG, "SECONDARY: don't free memzone of %s (%p)", mz_name, mz);
	}

	dev->attached = ML_DEV_DETACHED;
	ml_dev_globals.nb_devs--;

	return ret;
}

int
rte_ml_dev_init(size_t dev_max)
{
	if (dev_max == 0 || dev_max > INT16_MAX) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_max = %zu (must be > 0 and <= %d)", dev_max,
			      INT16_MAX);
		rte_errno = EINVAL;
		return -rte_errno;
	}

	/* No lock, it must be called before or during first probing. */
	if (ml_dev_globals.devs != NULL) {
		RTE_MLDEV_LOG(ERR, "Device array already initialized");
		rte_errno = EBUSY;
		return -rte_errno;
	}

	ml_dev_globals.devs = calloc(dev_max, sizeof(struct rte_ml_dev));
	if (ml_dev_globals.devs == NULL) {
		RTE_MLDEV_LOG(ERR, "Cannot initialize MLDEV library");
		rte_errno = ENOMEM;
		return -rte_errno;
	}

	ml_dev_globals.data = calloc(dev_max, sizeof(struct rte_ml_dev_data *));
	if (ml_dev_globals.data == NULL) {
		RTE_MLDEV_LOG(ERR, "Cannot initialize MLDEV library");
		/* avoid leaking the device array on partial failure */
		free(ml_dev_globals.devs);
		ml_dev_globals.devs = NULL;
		rte_errno = ENOMEM;
		return -rte_errno;
	}

	ml_dev_globals.max_devs = dev_max;

	return 0;
}

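/*
 * Illustrative sketch (not part of this file): an application that expects
 * more than RTE_MLDEV_DEFAULT_MAX devices can size the device array itself,
 * provided it does so before the first device is probed (probing triggers
 * the implicit rte_ml_dev_init() above). The value 64 is an arbitrary
 * example.
 *
 *	if (rte_ml_dev_init(64) != 0)
 *		rte_exit(EXIT_FAILURE, "cannot size ML device array: %s\n",
 *			 rte_strerror(rte_errno));
 */
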
uint16_t
rte_ml_dev_count(void)
{
	return ml_dev_globals.nb_devs;
}

int
rte_ml_dev_is_valid_dev(int16_t dev_id)
{
	struct rte_ml_dev *dev = NULL;

	/* guard against use before init and out-of-range ids */
	if (ml_dev_globals.devs == NULL || dev_id < 0 || dev_id >= ml_dev_globals.max_devs ||
	    ml_dev_globals.devs[dev_id].data == NULL)
		return 0;

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (dev->attached != ML_DEV_ATTACHED)
		return 0;
	else
		return 1;
}

int
rte_ml_dev_socket_id(int16_t dev_id)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);

	return dev->data->socket_id;
}

int
rte_ml_dev_info_get(int16_t dev_id, struct rte_ml_dev_info *dev_info)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->dev_info_get == NULL)
		return -ENOTSUP;

	if (dev_info == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, dev_info cannot be NULL", dev_id);
		return -EINVAL;
	}
	memset(dev_info, 0, sizeof(struct rte_ml_dev_info));

	return (*dev->dev_ops->dev_info_get)(dev, dev_info);
}

int
rte_ml_dev_configure(int16_t dev_id, const struct rte_ml_dev_config *config)
{
	struct rte_ml_dev_info dev_info;
	struct rte_ml_dev *dev;
	int ret;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->dev_configure == NULL)
		return -ENOTSUP;

	if (dev->data->dev_started) {
		RTE_MLDEV_LOG(ERR, "Device %d must be stopped to allow configuration", dev_id);
		return -EBUSY;
	}

	if (config == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, config cannot be NULL", dev_id);
		return -EINVAL;
	}

	ret = rte_ml_dev_info_get(dev_id, &dev_info);
	if (ret < 0)
		return ret;

	if (config->nb_queue_pairs > dev_info.max_queue_pairs) {
		RTE_MLDEV_LOG(ERR, "Device %d num of queues %u > %u", dev_id,
			      config->nb_queue_pairs, dev_info.max_queue_pairs);
		return -EINVAL;
	}

	return (*dev->dev_ops->dev_configure)(dev, config);
}

int
rte_ml_dev_close(int16_t dev_id)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->dev_close == NULL)
		return -ENOTSUP;

	/* Device must be stopped before it can be closed */
	if (dev->data->dev_started == 1) {
		RTE_MLDEV_LOG(ERR, "Device %d must be stopped before closing", dev_id);
		return -EBUSY;
	}

	return (*dev->dev_ops->dev_close)(dev);
}

int
rte_ml_dev_start(int16_t dev_id)
{
	struct rte_ml_dev *dev;
	int ret;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->dev_start == NULL)
		return -ENOTSUP;

	if (dev->data->dev_started != 0) {
		RTE_MLDEV_LOG(ERR, "Device %d is already started", dev_id);
		return -EBUSY;
	}

	ret = (*dev->dev_ops->dev_start)(dev);
	if (ret == 0)
		dev->data->dev_started = 1;

	return ret;
}

int
rte_ml_dev_stop(int16_t dev_id)
{
	struct rte_ml_dev *dev;
	int ret;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->dev_stop == NULL)
		return -ENOTSUP;

	if (dev->data->dev_started == 0) {
		RTE_MLDEV_LOG(ERR, "Device %d is not started", dev_id);
		return -EBUSY;
	}

	ret = (*dev->dev_ops->dev_stop)(dev);
	if (ret == 0)
		dev->data->dev_started = 0;

	return ret;
}

int
rte_ml_dev_queue_pair_setup(int16_t dev_id, uint16_t queue_pair_id,
			    const struct rte_ml_dev_qp_conf *qp_conf, int socket_id)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->dev_queue_pair_setup == NULL)
		return -ENOTSUP;

	if (queue_pair_id >= dev->data->nb_queue_pairs) {
		RTE_MLDEV_LOG(ERR, "Invalid queue_pair_id = %u", queue_pair_id);
		return -EINVAL;
	}

	if (qp_conf == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, qp_conf cannot be NULL", dev_id);
		return -EINVAL;
	}

	if (dev->data->dev_started) {
		RTE_MLDEV_LOG(ERR, "Device %d must be stopped to allow configuration", dev_id);
		return -EBUSY;
	}

	return (*dev->dev_ops->dev_queue_pair_setup)(dev, queue_pair_id, qp_conf, socket_id);
}

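/*
 * Illustrative sketch (not part of this file): the configure -> queue pair
 * setup -> start sequence an application would typically follow. Descriptor,
 * queue and model counts below are assumptions for the example.
 *
 *	struct rte_ml_dev_config conf = {
 *		.socket_id = rte_socket_id(),
 *		.nb_models = 1,
 *		.nb_queue_pairs = 1,
 *	};
 *	struct rte_ml_dev_qp_conf qp_conf = { .nb_desc = 128 };
 *
 *	if (rte_ml_dev_configure(dev_id, &conf) != 0 ||
 *	    rte_ml_dev_queue_pair_setup(dev_id, 0, &qp_conf, rte_socket_id()) != 0 ||
 *	    rte_ml_dev_start(dev_id) != 0)
 *		rte_exit(EXIT_FAILURE, "ML device %d setup failed\n", dev_id);
 */
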
int
rte_ml_dev_stats_get(int16_t dev_id, struct rte_ml_dev_stats *stats)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->dev_stats_get == NULL)
		return -ENOTSUP;

	if (stats == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, stats cannot be NULL", dev_id);
		return -EINVAL;
	}
	memset(stats, 0, sizeof(struct rte_ml_dev_stats));

	return (*dev->dev_ops->dev_stats_get)(dev, stats);
}

void
rte_ml_dev_stats_reset(int16_t dev_id)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->dev_stats_reset == NULL)
		return;

	(*dev->dev_ops->dev_stats_reset)(dev);
}

int
rte_ml_dev_xstats_names_get(int16_t dev_id, struct rte_ml_dev_xstats_map *xstats_map, uint32_t size)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->dev_xstats_names_get == NULL)
		return -ENOTSUP;

	return (*dev->dev_ops->dev_xstats_names_get)(dev, xstats_map, size);
}

int
rte_ml_dev_xstats_by_name_get(int16_t dev_id, const char *name, uint16_t *stat_id, uint64_t *value)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->dev_xstats_by_name_get == NULL)
		return -ENOTSUP;

	if (name == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, name cannot be NULL", dev_id);
		return -EINVAL;
	}

	if (value == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, value cannot be NULL", dev_id);
		return -EINVAL;
	}

	return (*dev->dev_ops->dev_xstats_by_name_get)(dev, name, stat_id, value);
}

int
rte_ml_dev_xstats_get(int16_t dev_id, const uint16_t *stat_ids, uint64_t *values, uint16_t nb_ids)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->dev_xstats_get == NULL)
		return -ENOTSUP;

	if (stat_ids == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, stat_ids cannot be NULL", dev_id);
		return -EINVAL;
	}

	if (values == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, values cannot be NULL", dev_id);
		return -EINVAL;
	}

	return (*dev->dev_ops->dev_xstats_get)(dev, stat_ids, values, nb_ids);
}

int
rte_ml_dev_xstats_reset(int16_t dev_id, const uint16_t *stat_ids, uint16_t nb_ids)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->dev_xstats_reset == NULL)
		return -ENOTSUP;

	return (*dev->dev_ops->dev_xstats_reset)(dev, stat_ids, nb_ids);
}

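/*
 * Illustrative sketch (not part of this file): dumping all extended stats,
 * assuming the usual DPDK xstats convention that the names_get call returns
 * the number of available counters. The fixed cap of 256 entries is an
 * assumption for the example.
 *
 *	struct rte_ml_dev_xstats_map map[256];
 *	uint64_t values[256];
 *	uint16_t ids[256];
 *	int n, i;
 *
 *	n = rte_ml_dev_xstats_names_get(dev_id, map, RTE_DIM(map));
 *	for (i = 0; i < n; i++)
 *		ids[i] = map[i].id;
 *	if (n > 0 && rte_ml_dev_xstats_get(dev_id, ids, values, n) == n)
 *		for (i = 0; i < n; i++)
 *			printf("%s: %" PRIu64 "\n", map[i].name, values[i]);
 */
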
int
rte_ml_dev_dump(int16_t dev_id, FILE *fd)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->dev_dump == NULL)
		return -ENOTSUP;

	if (fd == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, file handle cannot be NULL", dev_id);
		return -EINVAL;
	}

	return (*dev->dev_ops->dev_dump)(dev, fd);
}

int
rte_ml_dev_selftest(int16_t dev_id)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->dev_selftest == NULL)
		return -ENOTSUP;

	return (*dev->dev_ops->dev_selftest)(dev);
}

int
rte_ml_model_load(int16_t dev_id, struct rte_ml_model_params *params, uint16_t *model_id)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->model_load == NULL)
		return -ENOTSUP;

	if (params == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, params cannot be NULL", dev_id);
		return -EINVAL;
	}

	if (model_id == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, model_id cannot be NULL", dev_id);
		return -EINVAL;
	}

	return (*dev->dev_ops->model_load)(dev, params, model_id);
}

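/*
 * Illustrative sketch (not part of this file): loading and starting a model
 * from a buffer already read into memory. The buffer variables are
 * assumptions for the example; the model binary format is driver-specific.
 *
 *	struct rte_ml_model_params params = {
 *		.addr = model_buf,	// model binary in memory (assumed)
 *		.size = model_buf_len,
 *	};
 *	uint16_t model_id;
 *
 *	if (rte_ml_model_load(dev_id, &params, &model_id) != 0 ||
 *	    rte_ml_model_start(dev_id, model_id) != 0)
 *		rte_exit(EXIT_FAILURE, "cannot load/start model\n");
 *	// ... run inferences ...
 *	rte_ml_model_stop(dev_id, model_id);
 *	rte_ml_model_unload(dev_id, model_id);
 */
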
int
rte_ml_model_unload(int16_t dev_id, uint16_t model_id)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->model_unload == NULL)
		return -ENOTSUP;

	return (*dev->dev_ops->model_unload)(dev, model_id);
}

int
rte_ml_model_start(int16_t dev_id, uint16_t model_id)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->model_start == NULL)
		return -ENOTSUP;

	return (*dev->dev_ops->model_start)(dev, model_id);
}

int
rte_ml_model_stop(int16_t dev_id, uint16_t model_id)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->model_stop == NULL)
		return -ENOTSUP;

	return (*dev->dev_ops->model_stop)(dev, model_id);
}

int
rte_ml_model_info_get(int16_t dev_id, uint16_t model_id, struct rte_ml_model_info *model_info)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->model_info_get == NULL)
		return -ENOTSUP;

	if (model_info == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, model_id %u, model_info cannot be NULL", dev_id,
			      model_id);
		return -EINVAL;
	}

	return (*dev->dev_ops->model_info_get)(dev, model_id, model_info);
}

int
rte_ml_model_params_update(int16_t dev_id, uint16_t model_id, void *buffer)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->model_params_update == NULL)
		return -ENOTSUP;

	if (buffer == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, buffer cannot be NULL", dev_id);
		return -EINVAL;
	}

	return (*dev->dev_ops->model_params_update)(dev, model_id, buffer);
}

int
rte_ml_io_input_size_get(int16_t dev_id, uint16_t model_id, uint32_t nb_batches,
			 uint64_t *input_qsize, uint64_t *input_dsize)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->io_input_size_get == NULL)
		return -ENOTSUP;

	return (*dev->dev_ops->io_input_size_get)(dev, model_id, nb_batches, input_qsize,
						  input_dsize);
}

int
rte_ml_io_output_size_get(int16_t dev_id, uint16_t model_id, uint32_t nb_batches,
			  uint64_t *output_qsize, uint64_t *output_dsize)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->io_output_size_get == NULL)
		return -ENOTSUP;

	return (*dev->dev_ops->io_output_size_get)(dev, model_id, nb_batches, output_qsize,
						   output_dsize);
}

int
rte_ml_io_quantize(int16_t dev_id, uint16_t model_id, uint16_t nb_batches, void *dbuffer,
		   void *qbuffer)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->io_quantize == NULL)
		return -ENOTSUP;

	if (dbuffer == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, dbuffer cannot be NULL", dev_id);
		return -EINVAL;
	}

	if (qbuffer == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, qbuffer cannot be NULL", dev_id);
		return -EINVAL;
	}

	return (*dev->dev_ops->io_quantize)(dev, model_id, nb_batches, dbuffer, qbuffer);
}

int
rte_ml_io_dequantize(int16_t dev_id, uint16_t model_id, uint16_t nb_batches, void *qbuffer,
		     void *dbuffer)
{
	struct rte_ml_dev *dev;

	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dev_ops->io_dequantize == NULL)
		return -ENOTSUP;

	if (qbuffer == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, qbuffer cannot be NULL", dev_id);
		return -EINVAL;
	}

	if (dbuffer == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, dbuffer cannot be NULL", dev_id);
		return -EINVAL;
	}

	return (*dev->dev_ops->io_dequantize)(dev, model_id, nb_batches, qbuffer, dbuffer);
}

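/*
 * Illustrative sketch (not part of this file): sizing and converting I/O
 * buffers around an inference. rte_malloc'd buffers and a batch count of 1
 * are assumptions for the example.
 *
 *	uint64_t qsize, dsize;
 *	void *dbuf, *qbuf;
 *
 *	rte_ml_io_input_size_get(dev_id, model_id, 1, &qsize, &dsize);
 *	dbuf = rte_malloc(NULL, dsize, 0);	// dequantized (e.g. float) data
 *	qbuf = rte_malloc(NULL, qsize, 0);	// device (quantized) format
 *	// ... fill dbuf with input data ...
 *	rte_ml_io_quantize(dev_id, model_id, 1, dbuf, qbuf);
 *	// ... enqueue op with qbuf as input, run inference, then
 *	//     rte_ml_io_dequantize() the output the same way ...
 */
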
/** Initialise rte_ml_op mempool element */
static void
ml_op_init(struct rte_mempool *mempool, __rte_unused void *opaque_arg, void *_op_data,
	   __rte_unused unsigned int i)
{
	struct rte_ml_op *op = _op_data;

	memset(_op_data, 0, mempool->elt_size);
	op->status = RTE_ML_OP_STATUS_NOT_PROCESSED;
	op->mempool = mempool;
}

struct rte_mempool *
rte_ml_op_pool_create(const char *name, unsigned int nb_elts, unsigned int cache_size,
		      uint16_t user_size, int socket_id)
{
	struct rte_ml_op_pool_private *priv;
	struct rte_mempool *mp;
	unsigned int elt_size;

	/* lookup mempool in case already allocated */
	mp = rte_mempool_lookup(name);
	elt_size = sizeof(struct rte_ml_op) + user_size;

	if (mp != NULL) {
		priv = (struct rte_ml_op_pool_private *)rte_mempool_get_priv(mp);
		if (mp->elt_size != elt_size || mp->cache_size < cache_size || mp->size < nb_elts ||
		    priv->user_size < user_size) {
			RTE_MLDEV_LOG(ERR,
				      "Mempool %s already exists but with incompatible parameters",
				      name);
			return NULL;
		}
		return mp;
	}

	mp = rte_mempool_create(name, nb_elts, elt_size, cache_size,
				sizeof(struct rte_ml_op_pool_private), NULL, NULL, ml_op_init, NULL,
				socket_id, 0);
	if (mp == NULL) {
		RTE_MLDEV_LOG(ERR, "Failed to create mempool %s", name);
		return NULL;
	}

	priv = (struct rte_ml_op_pool_private *)rte_mempool_get_priv(mp);
	priv->user_size = user_size;

	return mp;
}

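/*
 * Illustrative sketch (not part of this file): creating an op pool with a
 * small per-op user area. Since user_size bytes are appended after each
 * struct rte_ml_op, (op + 1) is assumed here to be the start of that area.
 * The context struct and pool parameters are hypothetical.
 *
 *	struct my_op_ctx {	// hypothetical per-op user data
 *		uint64_t cookie;
 *	};
 *	struct rte_mempool *pool;
 *	struct rte_ml_op *op;
 *
 *	pool = rte_ml_op_pool_create("ml_op_pool", 1024, 64,
 *				     sizeof(struct my_op_ctx), rte_socket_id());
 *	if (pool == NULL || rte_mempool_get(pool, (void **)&op) != 0)
 *		rte_exit(EXIT_FAILURE, "cannot create/populate ML op pool\n");
 *	((struct my_op_ctx *)(op + 1))->cookie = 0xdeadbeef;
 */
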
void
rte_ml_op_pool_free(struct rte_mempool *mempool)
{
	if (mempool != NULL)
		rte_mempool_free(mempool);
}

uint16_t
rte_ml_enqueue_burst(int16_t dev_id, uint16_t qp_id, struct rte_ml_op **ops, uint16_t nb_ops)
{
	struct rte_ml_dev *dev;

#ifdef RTE_LIBRTE_ML_DEV_DEBUG
	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		rte_errno = EINVAL;
		return 0;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->enqueue_burst == NULL) {
		rte_errno = ENOTSUP;
		return 0;
	}

	if (ops == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, ops cannot be NULL", dev_id);
		rte_errno = EINVAL;
		return 0;
	}

	if (qp_id >= dev->data->nb_queue_pairs) {
		RTE_MLDEV_LOG(ERR, "Invalid qp_id %u", qp_id);
		rte_errno = EINVAL;
		return 0;
	}
#else
	dev = rte_ml_dev_pmd_get_dev(dev_id);
#endif

	return (*dev->enqueue_burst)(dev, qp_id, ops, nb_ops);
}

uint16_t
rte_ml_dequeue_burst(int16_t dev_id, uint16_t qp_id, struct rte_ml_op **ops, uint16_t nb_ops)
{
	struct rte_ml_dev *dev;

#ifdef RTE_LIBRTE_ML_DEV_DEBUG
	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		rte_errno = EINVAL;
		return 0;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->dequeue_burst == NULL) {
		rte_errno = ENOTSUP;
		return 0;
	}

	if (ops == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, ops cannot be NULL", dev_id);
		rte_errno = EINVAL;
		return 0;
	}

	if (qp_id >= dev->data->nb_queue_pairs) {
		RTE_MLDEV_LOG(ERR, "Invalid qp_id %u", qp_id);
		rte_errno = EINVAL;
		return 0;
	}
#else
	dev = rte_ml_dev_pmd_get_dev(dev_id);
#endif

	return (*dev->dequeue_burst)(dev, qp_id, ops, nb_ops);
}

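/*
 * Illustrative sketch (not part of this file): a minimal single-op
 * enqueue/poll loop on queue pair 0. "op_pool" is assumed to come from
 * rte_ml_op_pool_create(); op->input / op->output setup is driver- and
 * model-specific and is elided.
 *
 *	struct rte_ml_op *op, *done;
 *	struct rte_ml_op_error err;
 *
 *	rte_mempool_get(op_pool, (void **)&op);
 *	op->model_id = model_id;
 *	op->nb_batches = 1;
 *	// ... fill op->input / op->output ...
 *
 *	while (rte_ml_enqueue_burst(dev_id, 0, &op, 1) != 1)
 *		;	// queue pair full, retry
 *	while (rte_ml_dequeue_burst(dev_id, 0, &done, 1) != 1)
 *		;	// poll for completion
 *	if (done->status == RTE_ML_OP_STATUS_ERROR)
 *		rte_ml_op_error_get(dev_id, done, &err);
 *	rte_mempool_put(op_pool, done);
 */
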
int
rte_ml_op_error_get(int16_t dev_id, struct rte_ml_op *op, struct rte_ml_op_error *error)
{
	struct rte_ml_dev *dev;

#ifdef RTE_LIBRTE_ML_DEV_DEBUG
	if (!rte_ml_dev_is_valid_dev(dev_id)) {
		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
		return -EINVAL;
	}

	dev = rte_ml_dev_pmd_get_dev(dev_id);
	if (*dev->op_error_get == NULL)
		return -ENOTSUP;

	if (op == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, op cannot be NULL", dev_id);
		return -EINVAL;
	}

	if (error == NULL) {
		RTE_MLDEV_LOG(ERR, "Dev %d, error cannot be NULL", dev_id);
		return -EINVAL;
	}
#else
	dev = rte_ml_dev_pmd_get_dev(dev_id);
#endif

	return (*dev->op_error_get)(dev, op, error);
}

RTE_LOG_REGISTER_DEFAULT(rte_ml_dev_logtype, INFO);