xref: /dpdk/lib/mldev/rte_mldev.c (revision 804786f1012a35ee665e9d0fdf3bee1847a374fc)
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright (c) 2022 Marvell.
3  */
4 
5 #include <rte_errno.h>
6 #include <rte_log.h>
7 #include <rte_mldev.h>
8 #include <rte_mldev_pmd.h>
9 
10 #include <stdlib.h>
11 
12 static struct rte_ml_dev_global ml_dev_globals = {
13 	.devs = NULL, .data = NULL, .nb_devs = 0, .max_devs = RTE_MLDEV_DEFAULT_MAX};
14 
/*
 * Private data of an ML operation pool.
 *
 * This structure holds the op-pool specific data that is stored in the
 * private data area of the mempool (appended after the mempool structure).
 */
struct rte_ml_op_pool_private {
	uint16_t user_size; /**< Size of private user data with each operation. */
};
25 
26 struct rte_ml_dev *
27 rte_ml_dev_pmd_get_dev(int16_t dev_id)
28 {
29 	return &ml_dev_globals.devs[dev_id];
30 }
31 
32 struct rte_ml_dev *
33 rte_ml_dev_pmd_get_named_dev(const char *name)
34 {
35 	struct rte_ml_dev *dev;
36 	int16_t dev_id;
37 
38 	if (name == NULL)
39 		return NULL;
40 
41 	for (dev_id = 0; dev_id < ml_dev_globals.max_devs; dev_id++) {
42 		dev = rte_ml_dev_pmd_get_dev(dev_id);
43 		if ((dev->attached == ML_DEV_ATTACHED) && (strcmp(dev->data->name, name) == 0))
44 			return dev;
45 	}
46 
47 	return NULL;
48 }
49 
50 struct rte_ml_dev *
51 rte_ml_dev_pmd_allocate(const char *name, uint8_t socket_id)
52 {
53 	char mz_name[RTE_MEMZONE_NAMESIZE];
54 	const struct rte_memzone *mz;
55 	struct rte_ml_dev *dev;
56 	int16_t dev_id;
57 
58 	/* implicit initialization of library before adding first device */
59 	if (ml_dev_globals.devs == NULL) {
60 		if (rte_ml_dev_init(RTE_MLDEV_DEFAULT_MAX) != 0)
61 			return NULL;
62 	}
63 
64 	if (rte_ml_dev_pmd_get_named_dev(name) != NULL) {
65 		RTE_MLDEV_LOG(ERR, "ML device with name %s already allocated!", name);
66 		return NULL;
67 	}
68 
69 	/* Get a free device ID */
70 	for (dev_id = 0; dev_id < ml_dev_globals.max_devs; dev_id++) {
71 		dev = rte_ml_dev_pmd_get_dev(dev_id);
72 		if (dev->attached == ML_DEV_DETACHED)
73 			break;
74 	}
75 
76 	if (dev_id == ml_dev_globals.max_devs) {
77 		RTE_MLDEV_LOG(ERR, "Reached maximum number of ML devices");
78 		return NULL;
79 	}
80 
81 	if (dev->data == NULL) {
82 		/* Reserve memzone name */
83 		sprintf(mz_name, "rte_ml_dev_data_%d", dev_id);
84 		if (rte_eal_process_type() == RTE_PROC_PRIMARY) {
85 			mz = rte_memzone_reserve(mz_name, sizeof(struct rte_ml_dev_data), socket_id,
86 						 0);
87 			RTE_MLDEV_LOG(DEBUG, "PRIMARY: reserved memzone for %s (%p)", mz_name, mz);
88 		} else {
89 			mz = rte_memzone_lookup(mz_name);
90 			RTE_MLDEV_LOG(DEBUG, "SECONDARY: looked up memzone for %s (%p)", mz_name,
91 				      mz);
92 		}
93 
94 		if (mz == NULL)
95 			return NULL;
96 
97 		ml_dev_globals.data[dev_id] = mz->addr;
98 		if (rte_eal_process_type() == RTE_PROC_PRIMARY)
99 			memset(ml_dev_globals.data[dev_id], 0, sizeof(struct rte_ml_dev_data));
100 
101 		dev->data = ml_dev_globals.data[dev_id];
102 		if (rte_eal_process_type() == RTE_PROC_PRIMARY) {
103 			strlcpy(dev->data->name, name, RTE_ML_STR_MAX);
104 			dev->data->dev_id = dev_id;
105 			dev->data->socket_id = socket_id;
106 			dev->data->dev_started = 0;
107 			RTE_MLDEV_LOG(DEBUG, "PRIMARY: init mldev data");
108 		}
109 
110 		RTE_MLDEV_LOG(DEBUG, "Data for %s: dev_id %d, socket %u", dev->data->name,
111 			      dev->data->dev_id, dev->data->socket_id);
112 
113 		dev->attached = ML_DEV_ATTACHED;
114 		ml_dev_globals.nb_devs++;
115 	}
116 
117 	dev->enqueue_burst = NULL;
118 	dev->dequeue_burst = NULL;
119 
120 	return dev;
121 }
122 
123 int
124 rte_ml_dev_pmd_release(struct rte_ml_dev *dev)
125 {
126 	char mz_name[RTE_MEMZONE_NAMESIZE];
127 	const struct rte_memzone *mz;
128 	int16_t dev_id;
129 	int ret = 0;
130 
131 	if (dev == NULL)
132 		return -EINVAL;
133 
134 	dev_id = dev->data->dev_id;
135 
136 	/* Memzone lookup */
137 	sprintf(mz_name, "rte_ml_dev_data_%d", dev_id);
138 	mz = rte_memzone_lookup(mz_name);
139 	if (mz == NULL)
140 		return -ENOMEM;
141 
142 	RTE_ASSERT(ml_dev_globals.data[dev_id] == mz->addr);
143 	ml_dev_globals.data[dev_id] = NULL;
144 
145 	if (rte_eal_process_type() == RTE_PROC_PRIMARY) {
146 		RTE_MLDEV_LOG(DEBUG, "PRIMARY: free memzone of %s (%p)", mz_name, mz);
147 		ret = rte_memzone_free(mz);
148 	} else {
149 		RTE_MLDEV_LOG(DEBUG, "SECONDARY: don't free memzone of %s (%p)", mz_name, mz);
150 	}
151 
152 	dev->attached = ML_DEV_DETACHED;
153 	ml_dev_globals.nb_devs--;
154 
155 	return ret;
156 }
157 
158 int
159 rte_ml_dev_init(size_t dev_max)
160 {
161 	if (dev_max == 0 || dev_max > INT16_MAX) {
162 		RTE_MLDEV_LOG(ERR, "Invalid dev_max = %zu (> %d)", dev_max, INT16_MAX);
163 		rte_errno = EINVAL;
164 		return -rte_errno;
165 	}
166 
167 	/* No lock, it must be called before or during first probing. */
168 	if (ml_dev_globals.devs != NULL) {
169 		RTE_MLDEV_LOG(ERR, "Device array already initialized");
170 		rte_errno = EBUSY;
171 		return -rte_errno;
172 	}
173 
174 	ml_dev_globals.devs = calloc(dev_max, sizeof(struct rte_ml_dev));
175 	if (ml_dev_globals.devs == NULL) {
176 		RTE_MLDEV_LOG(ERR, "Cannot initialize MLDEV library");
177 		rte_errno = ENOMEM;
178 		return -rte_errno;
179 	}
180 
181 	ml_dev_globals.data = calloc(dev_max, sizeof(struct rte_ml_dev_data *));
182 	if (ml_dev_globals.data == NULL) {
183 		RTE_MLDEV_LOG(ERR, "Cannot initialize MLDEV library");
184 		rte_errno = ENOMEM;
185 		return -rte_errno;
186 	}
187 
188 	ml_dev_globals.max_devs = dev_max;
189 
190 	return 0;
191 }
192 
193 uint16_t
194 rte_ml_dev_count(void)
195 {
196 	return ml_dev_globals.nb_devs;
197 }
198 
199 int
200 rte_ml_dev_is_valid_dev(int16_t dev_id)
201 {
202 	struct rte_ml_dev *dev = NULL;
203 
204 	if (dev_id >= ml_dev_globals.max_devs || ml_dev_globals.devs[dev_id].data == NULL)
205 		return 0;
206 
207 	dev = rte_ml_dev_pmd_get_dev(dev_id);
208 	if (dev->attached != ML_DEV_ATTACHED)
209 		return 0;
210 	else
211 		return 1;
212 }
213 
214 int
215 rte_ml_dev_socket_id(int16_t dev_id)
216 {
217 	struct rte_ml_dev *dev;
218 
219 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
220 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
221 		return -EINVAL;
222 	}
223 
224 	dev = rte_ml_dev_pmd_get_dev(dev_id);
225 
226 	return dev->data->socket_id;
227 }
228 
229 int
230 rte_ml_dev_info_get(int16_t dev_id, struct rte_ml_dev_info *dev_info)
231 {
232 	struct rte_ml_dev *dev;
233 
234 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
235 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
236 		return -EINVAL;
237 	}
238 
239 	dev = rte_ml_dev_pmd_get_dev(dev_id);
240 	if (*dev->dev_ops->dev_info_get == NULL)
241 		return -ENOTSUP;
242 
243 	if (dev_info == NULL) {
244 		RTE_MLDEV_LOG(ERR, "Dev %d, dev_info cannot be NULL", dev_id);
245 		return -EINVAL;
246 	}
247 	memset(dev_info, 0, sizeof(struct rte_ml_dev_info));
248 
249 	return (*dev->dev_ops->dev_info_get)(dev, dev_info);
250 }
251 
252 int
253 rte_ml_dev_configure(int16_t dev_id, const struct rte_ml_dev_config *config)
254 {
255 	struct rte_ml_dev_info dev_info;
256 	struct rte_ml_dev *dev;
257 	int ret;
258 
259 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
260 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
261 		return -EINVAL;
262 	}
263 
264 	dev = rte_ml_dev_pmd_get_dev(dev_id);
265 	if (*dev->dev_ops->dev_configure == NULL)
266 		return -ENOTSUP;
267 
268 	if (dev->data->dev_started) {
269 		RTE_MLDEV_LOG(ERR, "Device %d must be stopped to allow configuration", dev_id);
270 		return -EBUSY;
271 	}
272 
273 	if (config == NULL) {
274 		RTE_MLDEV_LOG(ERR, "Dev %d, config cannot be NULL", dev_id);
275 		return -EINVAL;
276 	}
277 
278 	ret = rte_ml_dev_info_get(dev_id, &dev_info);
279 	if (ret < 0)
280 		return ret;
281 
282 	if (config->nb_queue_pairs > dev_info.max_queue_pairs) {
283 		RTE_MLDEV_LOG(ERR, "Device %d num of queues %u > %u", dev_id,
284 			      config->nb_queue_pairs, dev_info.max_queue_pairs);
285 		return -EINVAL;
286 	}
287 
288 	return (*dev->dev_ops->dev_configure)(dev, config);
289 }
290 
291 int
292 rte_ml_dev_close(int16_t dev_id)
293 {
294 	struct rte_ml_dev *dev;
295 
296 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
297 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
298 		return -EINVAL;
299 	}
300 
301 	dev = rte_ml_dev_pmd_get_dev(dev_id);
302 	if (*dev->dev_ops->dev_close == NULL)
303 		return -ENOTSUP;
304 
305 	/* Device must be stopped before it can be closed */
306 	if (dev->data->dev_started == 1) {
307 		RTE_MLDEV_LOG(ERR, "Device %d must be stopped before closing", dev_id);
308 		return -EBUSY;
309 	}
310 
311 	return (*dev->dev_ops->dev_close)(dev);
312 }
313 
314 int
315 rte_ml_dev_start(int16_t dev_id)
316 {
317 	struct rte_ml_dev *dev;
318 	int ret;
319 
320 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
321 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
322 		return -EINVAL;
323 	}
324 
325 	dev = rte_ml_dev_pmd_get_dev(dev_id);
326 	if (*dev->dev_ops->dev_start == NULL)
327 		return -ENOTSUP;
328 
329 	if (dev->data->dev_started != 0) {
330 		RTE_MLDEV_LOG(ERR, "Device %d is already started", dev_id);
331 		return -EBUSY;
332 	}
333 
334 	ret = (*dev->dev_ops->dev_start)(dev);
335 	if (ret == 0)
336 		dev->data->dev_started = 1;
337 
338 	return ret;
339 }
340 
341 int
342 rte_ml_dev_stop(int16_t dev_id)
343 {
344 	struct rte_ml_dev *dev;
345 	int ret;
346 
347 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
348 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
349 		return -EINVAL;
350 	}
351 
352 	dev = rte_ml_dev_pmd_get_dev(dev_id);
353 	if (*dev->dev_ops->dev_stop == NULL)
354 		return -ENOTSUP;
355 
356 	if (dev->data->dev_started == 0) {
357 		RTE_MLDEV_LOG(ERR, "Device %d is not started", dev_id);
358 		return -EBUSY;
359 	}
360 
361 	ret = (*dev->dev_ops->dev_stop)(dev);
362 	if (ret == 0)
363 		dev->data->dev_started = 0;
364 
365 	return ret;
366 }
367 
368 uint16_t
369 rte_ml_dev_queue_pair_count(int16_t dev_id)
370 {
371 	struct rte_ml_dev *dev;
372 
373 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
374 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
375 		return -EINVAL;
376 	}
377 
378 	dev = rte_ml_dev_pmd_get_dev(dev_id);
379 
380 	return dev->data->nb_queue_pairs;
381 }
382 
383 int
384 rte_ml_dev_queue_pair_setup(int16_t dev_id, uint16_t queue_pair_id,
385 			    const struct rte_ml_dev_qp_conf *qp_conf, int socket_id)
386 {
387 	struct rte_ml_dev *dev;
388 
389 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
390 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
391 		return -EINVAL;
392 	}
393 
394 	dev = rte_ml_dev_pmd_get_dev(dev_id);
395 	if (*dev->dev_ops->dev_queue_pair_setup == NULL)
396 		return -ENOTSUP;
397 
398 	if (queue_pair_id >= dev->data->nb_queue_pairs) {
399 		RTE_MLDEV_LOG(ERR, "Invalid queue_pair_id = %d", queue_pair_id);
400 		return -EINVAL;
401 	}
402 
403 	if (qp_conf == NULL) {
404 		RTE_MLDEV_LOG(ERR, "Dev %d, qp_conf cannot be NULL", dev_id);
405 		return -EINVAL;
406 	}
407 
408 	if (dev->data->dev_started) {
409 		RTE_MLDEV_LOG(ERR, "Device %d must be stopped to allow configuration", dev_id);
410 		return -EBUSY;
411 	}
412 
413 	return (*dev->dev_ops->dev_queue_pair_setup)(dev, queue_pair_id, qp_conf, socket_id);
414 }
415 
416 int
417 rte_ml_dev_stats_get(int16_t dev_id, struct rte_ml_dev_stats *stats)
418 {
419 	struct rte_ml_dev *dev;
420 
421 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
422 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
423 		return -EINVAL;
424 	}
425 
426 	dev = rte_ml_dev_pmd_get_dev(dev_id);
427 	if (*dev->dev_ops->dev_stats_get == NULL)
428 		return -ENOTSUP;
429 
430 	if (stats == NULL) {
431 		RTE_MLDEV_LOG(ERR, "Dev %d, stats cannot be NULL", dev_id);
432 		return -EINVAL;
433 	}
434 	memset(stats, 0, sizeof(struct rte_ml_dev_stats));
435 
436 	return (*dev->dev_ops->dev_stats_get)(dev, stats);
437 }
438 
439 void
440 rte_ml_dev_stats_reset(int16_t dev_id)
441 {
442 	struct rte_ml_dev *dev;
443 
444 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
445 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
446 		return;
447 	}
448 
449 	dev = rte_ml_dev_pmd_get_dev(dev_id);
450 	if (*dev->dev_ops->dev_stats_reset == NULL)
451 		return;
452 
453 	(*dev->dev_ops->dev_stats_reset)(dev);
454 }
455 
456 int
457 rte_ml_dev_xstats_names_get(int16_t dev_id, enum rte_ml_dev_xstats_mode mode, int32_t model_id,
458 			    struct rte_ml_dev_xstats_map *xstats_map, uint32_t size)
459 {
460 	struct rte_ml_dev *dev;
461 
462 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
463 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
464 		return -EINVAL;
465 	}
466 
467 	dev = rte_ml_dev_pmd_get_dev(dev_id);
468 	if (*dev->dev_ops->dev_xstats_names_get == NULL)
469 		return -ENOTSUP;
470 
471 	return (*dev->dev_ops->dev_xstats_names_get)(dev, mode, model_id, xstats_map, size);
472 }
473 
474 int
475 rte_ml_dev_xstats_by_name_get(int16_t dev_id, const char *name, uint16_t *stat_id, uint64_t *value)
476 {
477 	struct rte_ml_dev *dev;
478 
479 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
480 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
481 		return -EINVAL;
482 	}
483 
484 	dev = rte_ml_dev_pmd_get_dev(dev_id);
485 	if (*dev->dev_ops->dev_xstats_by_name_get == NULL)
486 		return -ENOTSUP;
487 
488 	if (name == NULL) {
489 		RTE_MLDEV_LOG(ERR, "Dev %d, name cannot be NULL", dev_id);
490 		return -EINVAL;
491 	}
492 
493 	if (value == NULL) {
494 		RTE_MLDEV_LOG(ERR, "Dev %d, value cannot be NULL", dev_id);
495 		return -EINVAL;
496 	}
497 
498 	return (*dev->dev_ops->dev_xstats_by_name_get)(dev, name, stat_id, value);
499 }
500 
501 int
502 rte_ml_dev_xstats_get(int16_t dev_id, enum rte_ml_dev_xstats_mode mode, int32_t model_id,
503 		      const uint16_t stat_ids[], uint64_t values[], uint16_t nb_ids)
504 {
505 	struct rte_ml_dev *dev;
506 
507 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
508 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
509 		return -EINVAL;
510 	}
511 
512 	dev = rte_ml_dev_pmd_get_dev(dev_id);
513 	if (*dev->dev_ops->dev_xstats_get == NULL)
514 		return -ENOTSUP;
515 
516 	if (stat_ids == NULL) {
517 		RTE_MLDEV_LOG(ERR, "Dev %d, stat_ids cannot be NULL", dev_id);
518 		return -EINVAL;
519 	}
520 
521 	if (values == NULL) {
522 		RTE_MLDEV_LOG(ERR, "Dev %d, values cannot be NULL", dev_id);
523 		return -EINVAL;
524 	}
525 
526 	return (*dev->dev_ops->dev_xstats_get)(dev, mode, model_id, stat_ids, values, nb_ids);
527 }
528 
529 int
530 rte_ml_dev_xstats_reset(int16_t dev_id, enum rte_ml_dev_xstats_mode mode, int32_t model_id,
531 			const uint16_t stat_ids[], uint16_t nb_ids)
532 {
533 	struct rte_ml_dev *dev;
534 
535 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
536 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
537 		return -EINVAL;
538 	}
539 
540 	dev = rte_ml_dev_pmd_get_dev(dev_id);
541 	if (*dev->dev_ops->dev_xstats_reset == NULL)
542 		return -ENOTSUP;
543 
544 	return (*dev->dev_ops->dev_xstats_reset)(dev, mode, model_id, stat_ids, nb_ids);
545 }
546 
547 int
548 rte_ml_dev_dump(int16_t dev_id, FILE *fd)
549 {
550 	struct rte_ml_dev *dev;
551 
552 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
553 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
554 		return -EINVAL;
555 	}
556 
557 	dev = rte_ml_dev_pmd_get_dev(dev_id);
558 	if (*dev->dev_ops->dev_dump == NULL)
559 		return -ENOTSUP;
560 
561 	if (fd == NULL) {
562 		RTE_MLDEV_LOG(ERR, "Dev %d, file descriptor cannot be NULL", dev_id);
563 		return -EINVAL;
564 	}
565 
566 	return (*dev->dev_ops->dev_dump)(dev, fd);
567 }
568 
569 int
570 rte_ml_dev_selftest(int16_t dev_id)
571 {
572 	struct rte_ml_dev *dev;
573 
574 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
575 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
576 		return -EINVAL;
577 	}
578 
579 	dev = rte_ml_dev_pmd_get_dev(dev_id);
580 	if (*dev->dev_ops->dev_selftest == NULL)
581 		return -ENOTSUP;
582 
583 	return (*dev->dev_ops->dev_selftest)(dev);
584 }
585 
586 int
587 rte_ml_model_load(int16_t dev_id, struct rte_ml_model_params *params, uint16_t *model_id)
588 {
589 	struct rte_ml_dev *dev;
590 
591 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
592 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
593 		return -EINVAL;
594 	}
595 
596 	dev = rte_ml_dev_pmd_get_dev(dev_id);
597 	if (*dev->dev_ops->model_load == NULL)
598 		return -ENOTSUP;
599 
600 	if (params == NULL) {
601 		RTE_MLDEV_LOG(ERR, "Dev %d, params cannot be NULL", dev_id);
602 		return -EINVAL;
603 	}
604 
605 	if (model_id == NULL) {
606 		RTE_MLDEV_LOG(ERR, "Dev %d, model_id cannot be NULL", dev_id);
607 		return -EINVAL;
608 	}
609 
610 	return (*dev->dev_ops->model_load)(dev, params, model_id);
611 }
612 
613 int
614 rte_ml_model_unload(int16_t dev_id, uint16_t model_id)
615 {
616 	struct rte_ml_dev *dev;
617 
618 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
619 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
620 		return -EINVAL;
621 	}
622 
623 	dev = rte_ml_dev_pmd_get_dev(dev_id);
624 	if (*dev->dev_ops->model_unload == NULL)
625 		return -ENOTSUP;
626 
627 	return (*dev->dev_ops->model_unload)(dev, model_id);
628 }
629 
630 int
631 rte_ml_model_start(int16_t dev_id, uint16_t model_id)
632 {
633 	struct rte_ml_dev *dev;
634 
635 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
636 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
637 		return -EINVAL;
638 	}
639 
640 	dev = rte_ml_dev_pmd_get_dev(dev_id);
641 	if (*dev->dev_ops->model_start == NULL)
642 		return -ENOTSUP;
643 
644 	return (*dev->dev_ops->model_start)(dev, model_id);
645 }
646 
647 int
648 rte_ml_model_stop(int16_t dev_id, uint16_t model_id)
649 {
650 	struct rte_ml_dev *dev;
651 
652 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
653 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
654 		return -EINVAL;
655 	}
656 
657 	dev = rte_ml_dev_pmd_get_dev(dev_id);
658 	if (*dev->dev_ops->model_stop == NULL)
659 		return -ENOTSUP;
660 
661 	return (*dev->dev_ops->model_stop)(dev, model_id);
662 }
663 
664 int
665 rte_ml_model_info_get(int16_t dev_id, uint16_t model_id, struct rte_ml_model_info *model_info)
666 {
667 	struct rte_ml_dev *dev;
668 
669 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
670 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
671 		return -EINVAL;
672 	}
673 
674 	dev = rte_ml_dev_pmd_get_dev(dev_id);
675 	if (*dev->dev_ops->model_info_get == NULL)
676 		return -ENOTSUP;
677 
678 	if (model_info == NULL) {
679 		RTE_MLDEV_LOG(ERR, "Dev %d, model_id %u, model_info cannot be NULL", dev_id,
680 			      model_id);
681 		return -EINVAL;
682 	}
683 
684 	return (*dev->dev_ops->model_info_get)(dev, model_id, model_info);
685 }
686 
687 int
688 rte_ml_model_params_update(int16_t dev_id, uint16_t model_id, void *buffer)
689 {
690 	struct rte_ml_dev *dev;
691 
692 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
693 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
694 		return -EINVAL;
695 	}
696 
697 	dev = rte_ml_dev_pmd_get_dev(dev_id);
698 	if (*dev->dev_ops->model_params_update == NULL)
699 		return -ENOTSUP;
700 
701 	if (buffer == NULL) {
702 		RTE_MLDEV_LOG(ERR, "Dev %d, buffer cannot be NULL", dev_id);
703 		return -EINVAL;
704 	}
705 
706 	return (*dev->dev_ops->model_params_update)(dev, model_id, buffer);
707 }
708 
709 int
710 rte_ml_io_quantize(int16_t dev_id, uint16_t model_id, struct rte_ml_buff_seg **dbuffer,
711 		   struct rte_ml_buff_seg **qbuffer)
712 {
713 	struct rte_ml_dev *dev;
714 
715 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
716 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
717 		return -EINVAL;
718 	}
719 
720 	dev = rte_ml_dev_pmd_get_dev(dev_id);
721 	if (*dev->dev_ops->io_quantize == NULL)
722 		return -ENOTSUP;
723 
724 	if (dbuffer == NULL) {
725 		RTE_MLDEV_LOG(ERR, "Dev %d, dbuffer cannot be NULL", dev_id);
726 		return -EINVAL;
727 	}
728 
729 	if (qbuffer == NULL) {
730 		RTE_MLDEV_LOG(ERR, "Dev %d, qbuffer cannot be NULL", dev_id);
731 		return -EINVAL;
732 	}
733 
734 	return (*dev->dev_ops->io_quantize)(dev, model_id, dbuffer, qbuffer);
735 }
736 
737 int
738 rte_ml_io_dequantize(int16_t dev_id, uint16_t model_id, struct rte_ml_buff_seg **qbuffer,
739 		     struct rte_ml_buff_seg **dbuffer)
740 {
741 	struct rte_ml_dev *dev;
742 
743 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
744 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
745 		return -EINVAL;
746 	}
747 
748 	dev = rte_ml_dev_pmd_get_dev(dev_id);
749 	if (*dev->dev_ops->io_dequantize == NULL)
750 		return -ENOTSUP;
751 
752 	if (qbuffer == NULL) {
753 		RTE_MLDEV_LOG(ERR, "Dev %d, qbuffer cannot be NULL", dev_id);
754 		return -EINVAL;
755 	}
756 
757 	if (dbuffer == NULL) {
758 		RTE_MLDEV_LOG(ERR, "Dev %d, dbuffer cannot be NULL", dev_id);
759 		return -EINVAL;
760 	}
761 
762 	return (*dev->dev_ops->io_dequantize)(dev, model_id, qbuffer, dbuffer);
763 }
764 
765 /** Initialise rte_ml_op mempool element */
766 static void
767 ml_op_init(struct rte_mempool *mempool, __rte_unused void *opaque_arg, void *_op_data,
768 	   __rte_unused unsigned int i)
769 {
770 	struct rte_ml_op *op = _op_data;
771 
772 	memset(_op_data, 0, mempool->elt_size);
773 	op->status = RTE_ML_OP_STATUS_NOT_PROCESSED;
774 	op->mempool = mempool;
775 }
776 
777 struct rte_mempool *
778 rte_ml_op_pool_create(const char *name, unsigned int nb_elts, unsigned int cache_size,
779 		      uint16_t user_size, int socket_id)
780 {
781 	struct rte_ml_op_pool_private *priv;
782 	struct rte_mempool *mp;
783 	unsigned int elt_size;
784 
785 	/* lookup mempool in case already allocated */
786 	mp = rte_mempool_lookup(name);
787 	elt_size = sizeof(struct rte_ml_op) + user_size;
788 
789 	if (mp != NULL) {
790 		priv = (struct rte_ml_op_pool_private *)rte_mempool_get_priv(mp);
791 		if (mp->elt_size != elt_size || mp->cache_size < cache_size || mp->size < nb_elts ||
792 		    priv->user_size < user_size) {
793 			mp = NULL;
794 			RTE_MLDEV_LOG(ERR,
795 				      "Mempool %s already exists but with incompatible parameters",
796 				      name);
797 			return NULL;
798 		}
799 		return mp;
800 	}
801 
802 	mp = rte_mempool_create(name, nb_elts, elt_size, cache_size,
803 				sizeof(struct rte_ml_op_pool_private), NULL, NULL, ml_op_init, NULL,
804 				socket_id, 0);
805 	if (mp == NULL) {
806 		RTE_MLDEV_LOG(ERR, "Failed to create mempool %s", name);
807 		return NULL;
808 	}
809 
810 	priv = (struct rte_ml_op_pool_private *)rte_mempool_get_priv(mp);
811 	priv->user_size = user_size;
812 
813 	return mp;
814 }
815 
/* Free an ML operation pool by handing it to rte_mempool_free(). */
void
rte_ml_op_pool_free(struct rte_mempool *mempool)
{
	struct rte_mempool *pool = mempool;

	rte_mempool_free(pool);
}
821 
822 uint16_t
823 rte_ml_enqueue_burst(int16_t dev_id, uint16_t qp_id, struct rte_ml_op **ops, uint16_t nb_ops)
824 {
825 	struct rte_ml_dev *dev;
826 
827 #ifdef RTE_LIBRTE_ML_DEV_DEBUG
828 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
829 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
830 		rte_errno = -EINVAL;
831 		return 0;
832 	}
833 
834 	dev = rte_ml_dev_pmd_get_dev(dev_id);
835 	if (*dev->enqueue_burst == NULL) {
836 		rte_errno = -ENOTSUP;
837 		return 0;
838 	}
839 
840 	if (ops == NULL) {
841 		RTE_MLDEV_LOG(ERR, "Dev %d, ops cannot be NULL", dev_id);
842 		rte_errno = -EINVAL;
843 		return 0;
844 	}
845 
846 	if (qp_id >= dev->data->nb_queue_pairs) {
847 		RTE_MLDEV_LOG(ERR, "Invalid qp_id %u", qp_id);
848 		rte_errno = -EINVAL;
849 		return 0;
850 	}
851 #else
852 	dev = rte_ml_dev_pmd_get_dev(dev_id);
853 #endif
854 
855 	return (*dev->enqueue_burst)(dev, qp_id, ops, nb_ops);
856 }
857 
858 uint16_t
859 rte_ml_dequeue_burst(int16_t dev_id, uint16_t qp_id, struct rte_ml_op **ops, uint16_t nb_ops)
860 {
861 	struct rte_ml_dev *dev;
862 
863 #ifdef RTE_LIBRTE_ML_DEV_DEBUG
864 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
865 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
866 		rte_errno = -EINVAL;
867 		return 0;
868 	}
869 
870 	dev = rte_ml_dev_pmd_get_dev(dev_id);
871 	if (*dev->dequeue_burst == NULL) {
872 		rte_errno = -ENOTSUP;
873 		return 0;
874 	}
875 
876 	if (ops == NULL) {
877 		RTE_MLDEV_LOG(ERR, "Dev %d, ops cannot be NULL", dev_id);
878 		rte_errno = -EINVAL;
879 		return 0;
880 	}
881 
882 	if (qp_id >= dev->data->nb_queue_pairs) {
883 		RTE_MLDEV_LOG(ERR, "Invalid qp_id %u", qp_id);
884 		rte_errno = -EINVAL;
885 		return 0;
886 	}
887 #else
888 	dev = rte_ml_dev_pmd_get_dev(dev_id);
889 #endif
890 
891 	return (*dev->dequeue_burst)(dev, qp_id, ops, nb_ops);
892 }
893 
894 int
895 rte_ml_op_error_get(int16_t dev_id, struct rte_ml_op *op, struct rte_ml_op_error *error)
896 {
897 	struct rte_ml_dev *dev;
898 
899 #ifdef RTE_LIBRTE_ML_DEV_DEBUG
900 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
901 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
902 		return -EINVAL;
903 	}
904 
905 	dev = rte_ml_dev_pmd_get_dev(dev_id);
906 	if (*dev->op_error_get == NULL)
907 		return -ENOTSUP;
908 
909 	if (op == NULL) {
910 		RTE_MLDEV_LOG(ERR, "Dev %d, op cannot be NULL", dev_id);
911 		return -EINVAL;
912 	}
913 
914 	if (error == NULL) {
915 		RTE_MLDEV_LOG(ERR, "Dev %d, error cannot be NULL", dev_id);
916 		return -EINVAL;
917 	}
918 #else
919 	dev = rte_ml_dev_pmd_get_dev(dev_id);
920 #endif
921 
922 	return (*dev->op_error_get)(dev, op, error);
923 }
924 
/* Log type registration for this library; presumably INFO is the default level - confirm against rte_log.h. */
RTE_LOG_REGISTER_DEFAULT(rte_ml_dev_logtype, INFO);
926