xref: /dpdk/lib/mldev/rte_mldev.c (revision de1f01a8eabd1da08d85e77ff99ba85e03cfd1ad)
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright (c) 2022 Marvell.
3  */
4 
5 #include <rte_errno.h>
6 #include <rte_log.h>
7 #include <rte_mldev.h>
8 #include <rte_mldev_pmd.h>
9 
10 #include <stdlib.h>
11 
12 static struct rte_ml_dev_global ml_dev_globals = {
13 	.devs = NULL, .data = NULL, .nb_devs = 0, .max_devs = RTE_MLDEV_DEFAULT_MAX};
14 
/*
 * Private data structure of an operation pool.
 *
 * Stored in the mempool private area (rte_mempool_get_priv()) of pools
 * created by rte_ml_op_pool_create(). A later create call with the same
 * name uses it to check that the existing pool is compatible.
 */
struct rte_ml_op_pool_private {
	uint16_t user_size; /**< Size of private user data with each operation. */
};
25 
26 struct rte_ml_dev *
27 rte_ml_dev_pmd_get_dev(int16_t dev_id)
28 {
29 	return &ml_dev_globals.devs[dev_id];
30 }
31 
32 struct rte_ml_dev *
33 rte_ml_dev_pmd_get_named_dev(const char *name)
34 {
35 	struct rte_ml_dev *dev;
36 	int16_t dev_id;
37 
38 	if (name == NULL)
39 		return NULL;
40 
41 	for (dev_id = 0; dev_id < ml_dev_globals.max_devs; dev_id++) {
42 		dev = rte_ml_dev_pmd_get_dev(dev_id);
43 		if ((dev->attached == ML_DEV_ATTACHED) && (strcmp(dev->data->name, name) == 0))
44 			return dev;
45 	}
46 
47 	return NULL;
48 }
49 
50 struct rte_ml_dev *
51 rte_ml_dev_pmd_allocate(const char *name, uint8_t socket_id)
52 {
53 	char mz_name[RTE_MEMZONE_NAMESIZE];
54 	const struct rte_memzone *mz;
55 	struct rte_ml_dev *dev;
56 	int16_t dev_id;
57 
58 	/* implicit initialization of library before adding first device */
59 	if (ml_dev_globals.devs == NULL) {
60 		if (rte_ml_dev_init(RTE_MLDEV_DEFAULT_MAX) != 0)
61 			return NULL;
62 	}
63 
64 	if (rte_ml_dev_pmd_get_named_dev(name) != NULL) {
65 		RTE_MLDEV_LOG(ERR, "ML device with name %s already allocated!", name);
66 		return NULL;
67 	}
68 
69 	/* Get a free device ID */
70 	for (dev_id = 0; dev_id < ml_dev_globals.max_devs; dev_id++) {
71 		dev = rte_ml_dev_pmd_get_dev(dev_id);
72 		if (dev->attached == ML_DEV_DETACHED)
73 			break;
74 	}
75 
76 	if (dev_id == ml_dev_globals.max_devs) {
77 		RTE_MLDEV_LOG(ERR, "Reached maximum number of ML devices");
78 		return NULL;
79 	}
80 
81 	if (dev->data == NULL) {
82 		/* Reserve memzone name */
83 		sprintf(mz_name, "rte_ml_dev_data_%d", dev_id);
84 		if (rte_eal_process_type() == RTE_PROC_PRIMARY) {
85 			mz = rte_memzone_reserve(mz_name, sizeof(struct rte_ml_dev_data), socket_id,
86 						 0);
87 			RTE_MLDEV_LOG(DEBUG, "PRIMARY: reserved memzone for %s (%p)", mz_name, mz);
88 		} else {
89 			mz = rte_memzone_lookup(mz_name);
90 			RTE_MLDEV_LOG(DEBUG, "SECONDARY: looked up memzone for %s (%p)", mz_name,
91 				      mz);
92 		}
93 
94 		if (mz == NULL)
95 			return NULL;
96 
97 		ml_dev_globals.data[dev_id] = mz->addr;
98 		if (rte_eal_process_type() == RTE_PROC_PRIMARY)
99 			memset(ml_dev_globals.data[dev_id], 0, sizeof(struct rte_ml_dev_data));
100 
101 		dev->data = ml_dev_globals.data[dev_id];
102 		if (rte_eal_process_type() == RTE_PROC_PRIMARY) {
103 			strlcpy(dev->data->name, name, RTE_ML_STR_MAX);
104 			dev->data->dev_id = dev_id;
105 			dev->data->socket_id = socket_id;
106 			dev->data->dev_started = 0;
107 			RTE_MLDEV_LOG(DEBUG, "PRIMARY: init mldev data");
108 		}
109 
110 		RTE_MLDEV_LOG(DEBUG, "Data for %s: dev_id %d, socket %u", dev->data->name,
111 			      dev->data->dev_id, dev->data->socket_id);
112 
113 		dev->attached = ML_DEV_ATTACHED;
114 		ml_dev_globals.nb_devs++;
115 	}
116 
117 	dev->enqueue_burst = NULL;
118 	dev->dequeue_burst = NULL;
119 
120 	return dev;
121 }
122 
123 int
124 rte_ml_dev_pmd_release(struct rte_ml_dev *dev)
125 {
126 	char mz_name[RTE_MEMZONE_NAMESIZE];
127 	const struct rte_memzone *mz;
128 	int16_t dev_id;
129 	int ret = 0;
130 
131 	if (dev == NULL)
132 		return -EINVAL;
133 
134 	dev_id = dev->data->dev_id;
135 
136 	/* Memzone lookup */
137 	sprintf(mz_name, "rte_ml_dev_data_%d", dev_id);
138 	mz = rte_memzone_lookup(mz_name);
139 	if (mz == NULL)
140 		return -ENOMEM;
141 
142 	RTE_ASSERT(ml_dev_globals.data[dev_id] == mz->addr);
143 	ml_dev_globals.data[dev_id] = NULL;
144 
145 	if (rte_eal_process_type() == RTE_PROC_PRIMARY) {
146 		RTE_MLDEV_LOG(DEBUG, "PRIMARY: free memzone of %s (%p)", mz_name, mz);
147 		ret = rte_memzone_free(mz);
148 	} else {
149 		RTE_MLDEV_LOG(DEBUG, "SECONDARY: don't free memzone of %s (%p)", mz_name, mz);
150 	}
151 
152 	dev->attached = ML_DEV_DETACHED;
153 	ml_dev_globals.nb_devs--;
154 
155 	return ret;
156 }
157 
158 int
159 rte_ml_dev_init(size_t dev_max)
160 {
161 	if (dev_max == 0 || dev_max > INT16_MAX) {
162 		RTE_MLDEV_LOG(ERR, "Invalid dev_max = %zu (> %d)", dev_max, INT16_MAX);
163 		rte_errno = EINVAL;
164 		return -rte_errno;
165 	}
166 
167 	/* No lock, it must be called before or during first probing. */
168 	if (ml_dev_globals.devs != NULL) {
169 		RTE_MLDEV_LOG(ERR, "Device array already initialized");
170 		rte_errno = EBUSY;
171 		return -rte_errno;
172 	}
173 
174 	ml_dev_globals.devs = calloc(dev_max, sizeof(struct rte_ml_dev));
175 	if (ml_dev_globals.devs == NULL) {
176 		RTE_MLDEV_LOG(ERR, "Cannot initialize MLDEV library");
177 		rte_errno = ENOMEM;
178 		return -rte_errno;
179 	}
180 
181 	ml_dev_globals.data = calloc(dev_max, sizeof(struct rte_ml_dev_data *));
182 	if (ml_dev_globals.data == NULL) {
183 		RTE_MLDEV_LOG(ERR, "Cannot initialize MLDEV library");
184 		rte_errno = ENOMEM;
185 		return -rte_errno;
186 	}
187 
188 	ml_dev_globals.max_devs = dev_max;
189 
190 	return 0;
191 }
192 
193 uint16_t
194 rte_ml_dev_count(void)
195 {
196 	return ml_dev_globals.nb_devs;
197 }
198 
199 int
200 rte_ml_dev_is_valid_dev(int16_t dev_id)
201 {
202 	struct rte_ml_dev *dev = NULL;
203 
204 	if (dev_id >= ml_dev_globals.max_devs || ml_dev_globals.devs[dev_id].data == NULL)
205 		return 0;
206 
207 	dev = rte_ml_dev_pmd_get_dev(dev_id);
208 	if (dev->attached != ML_DEV_ATTACHED)
209 		return 0;
210 	else
211 		return 1;
212 }
213 
214 int
215 rte_ml_dev_socket_id(int16_t dev_id)
216 {
217 	struct rte_ml_dev *dev;
218 
219 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
220 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
221 		return -EINVAL;
222 	}
223 
224 	dev = rte_ml_dev_pmd_get_dev(dev_id);
225 
226 	return dev->data->socket_id;
227 }
228 
229 int
230 rte_ml_dev_info_get(int16_t dev_id, struct rte_ml_dev_info *dev_info)
231 {
232 	struct rte_ml_dev *dev;
233 
234 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
235 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
236 		return -EINVAL;
237 	}
238 
239 	dev = rte_ml_dev_pmd_get_dev(dev_id);
240 	if (*dev->dev_ops->dev_info_get == NULL)
241 		return -ENOTSUP;
242 
243 	if (dev_info == NULL) {
244 		RTE_MLDEV_LOG(ERR, "Dev %d, dev_info cannot be NULL", dev_id);
245 		return -EINVAL;
246 	}
247 	memset(dev_info, 0, sizeof(struct rte_ml_dev_info));
248 
249 	return (*dev->dev_ops->dev_info_get)(dev, dev_info);
250 }
251 
252 int
253 rte_ml_dev_configure(int16_t dev_id, const struct rte_ml_dev_config *config)
254 {
255 	struct rte_ml_dev_info dev_info;
256 	struct rte_ml_dev *dev;
257 	int ret;
258 
259 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
260 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
261 		return -EINVAL;
262 	}
263 
264 	dev = rte_ml_dev_pmd_get_dev(dev_id);
265 	if (*dev->dev_ops->dev_configure == NULL)
266 		return -ENOTSUP;
267 
268 	if (dev->data->dev_started) {
269 		RTE_MLDEV_LOG(ERR, "Device %d must be stopped to allow configuration", dev_id);
270 		return -EBUSY;
271 	}
272 
273 	if (config == NULL) {
274 		RTE_MLDEV_LOG(ERR, "Dev %d, config cannot be NULL", dev_id);
275 		return -EINVAL;
276 	}
277 
278 	ret = rte_ml_dev_info_get(dev_id, &dev_info);
279 	if (ret < 0)
280 		return ret;
281 
282 	if (config->nb_queue_pairs > dev_info.max_queue_pairs) {
283 		RTE_MLDEV_LOG(ERR, "Device %d num of queues %u > %u", dev_id,
284 			      config->nb_queue_pairs, dev_info.max_queue_pairs);
285 		return -EINVAL;
286 	}
287 
288 	return (*dev->dev_ops->dev_configure)(dev, config);
289 }
290 
291 int
292 rte_ml_dev_close(int16_t dev_id)
293 {
294 	struct rte_ml_dev *dev;
295 
296 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
297 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
298 		return -EINVAL;
299 	}
300 
301 	dev = rte_ml_dev_pmd_get_dev(dev_id);
302 	if (*dev->dev_ops->dev_close == NULL)
303 		return -ENOTSUP;
304 
305 	/* Device must be stopped before it can be closed */
306 	if (dev->data->dev_started == 1) {
307 		RTE_MLDEV_LOG(ERR, "Device %d must be stopped before closing", dev_id);
308 		return -EBUSY;
309 	}
310 
311 	return (*dev->dev_ops->dev_close)(dev);
312 }
313 
314 int
315 rte_ml_dev_start(int16_t dev_id)
316 {
317 	struct rte_ml_dev *dev;
318 	int ret;
319 
320 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
321 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
322 		return -EINVAL;
323 	}
324 
325 	dev = rte_ml_dev_pmd_get_dev(dev_id);
326 	if (*dev->dev_ops->dev_start == NULL)
327 		return -ENOTSUP;
328 
329 	if (dev->data->dev_started != 0) {
330 		RTE_MLDEV_LOG(ERR, "Device %d is already started", dev_id);
331 		return -EBUSY;
332 	}
333 
334 	ret = (*dev->dev_ops->dev_start)(dev);
335 	if (ret == 0)
336 		dev->data->dev_started = 1;
337 
338 	return ret;
339 }
340 
341 int
342 rte_ml_dev_stop(int16_t dev_id)
343 {
344 	struct rte_ml_dev *dev;
345 	int ret;
346 
347 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
348 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
349 		return -EINVAL;
350 	}
351 
352 	dev = rte_ml_dev_pmd_get_dev(dev_id);
353 	if (*dev->dev_ops->dev_stop == NULL)
354 		return -ENOTSUP;
355 
356 	if (dev->data->dev_started == 0) {
357 		RTE_MLDEV_LOG(ERR, "Device %d is not started", dev_id);
358 		return -EBUSY;
359 	}
360 
361 	ret = (*dev->dev_ops->dev_stop)(dev);
362 	if (ret == 0)
363 		dev->data->dev_started = 0;
364 
365 	return ret;
366 }
367 
368 int
369 rte_ml_dev_queue_pair_setup(int16_t dev_id, uint16_t queue_pair_id,
370 			    const struct rte_ml_dev_qp_conf *qp_conf, int socket_id)
371 {
372 	struct rte_ml_dev *dev;
373 
374 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
375 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
376 		return -EINVAL;
377 	}
378 
379 	dev = rte_ml_dev_pmd_get_dev(dev_id);
380 	if (*dev->dev_ops->dev_queue_pair_setup == NULL)
381 		return -ENOTSUP;
382 
383 	if (queue_pair_id >= dev->data->nb_queue_pairs) {
384 		RTE_MLDEV_LOG(ERR, "Invalid queue_pair_id = %d", queue_pair_id);
385 		return -EINVAL;
386 	}
387 
388 	if (qp_conf == NULL) {
389 		RTE_MLDEV_LOG(ERR, "Dev %d, qp_conf cannot be NULL", dev_id);
390 		return -EINVAL;
391 	}
392 
393 	if (dev->data->dev_started) {
394 		RTE_MLDEV_LOG(ERR, "Device %d must be stopped to allow configuration", dev_id);
395 		return -EBUSY;
396 	}
397 
398 	return (*dev->dev_ops->dev_queue_pair_setup)(dev, queue_pair_id, qp_conf, socket_id);
399 }
400 
401 int
402 rte_ml_dev_stats_get(int16_t dev_id, struct rte_ml_dev_stats *stats)
403 {
404 	struct rte_ml_dev *dev;
405 
406 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
407 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
408 		return -EINVAL;
409 	}
410 
411 	dev = rte_ml_dev_pmd_get_dev(dev_id);
412 	if (*dev->dev_ops->dev_stats_get == NULL)
413 		return -ENOTSUP;
414 
415 	if (stats == NULL) {
416 		RTE_MLDEV_LOG(ERR, "Dev %d, stats cannot be NULL", dev_id);
417 		return -EINVAL;
418 	}
419 	memset(stats, 0, sizeof(struct rte_ml_dev_stats));
420 
421 	return (*dev->dev_ops->dev_stats_get)(dev, stats);
422 }
423 
424 void
425 rte_ml_dev_stats_reset(int16_t dev_id)
426 {
427 	struct rte_ml_dev *dev;
428 
429 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
430 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
431 		return;
432 	}
433 
434 	dev = rte_ml_dev_pmd_get_dev(dev_id);
435 	if (*dev->dev_ops->dev_stats_reset == NULL)
436 		return;
437 
438 	(*dev->dev_ops->dev_stats_reset)(dev);
439 }
440 
441 int
442 rte_ml_dev_xstats_names_get(int16_t dev_id, enum rte_ml_dev_xstats_mode mode, int32_t model_id,
443 			    struct rte_ml_dev_xstats_map *xstats_map, uint32_t size)
444 {
445 	struct rte_ml_dev *dev;
446 
447 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
448 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
449 		return -EINVAL;
450 	}
451 
452 	dev = rte_ml_dev_pmd_get_dev(dev_id);
453 	if (*dev->dev_ops->dev_xstats_names_get == NULL)
454 		return -ENOTSUP;
455 
456 	return (*dev->dev_ops->dev_xstats_names_get)(dev, mode, model_id, xstats_map, size);
457 }
458 
459 int
460 rte_ml_dev_xstats_by_name_get(int16_t dev_id, const char *name, uint16_t *stat_id, uint64_t *value)
461 {
462 	struct rte_ml_dev *dev;
463 
464 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
465 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
466 		return -EINVAL;
467 	}
468 
469 	dev = rte_ml_dev_pmd_get_dev(dev_id);
470 	if (*dev->dev_ops->dev_xstats_by_name_get == NULL)
471 		return -ENOTSUP;
472 
473 	if (name == NULL) {
474 		RTE_MLDEV_LOG(ERR, "Dev %d, name cannot be NULL", dev_id);
475 		return -EINVAL;
476 	}
477 
478 	if (value == NULL) {
479 		RTE_MLDEV_LOG(ERR, "Dev %d, value cannot be NULL", dev_id);
480 		return -EINVAL;
481 	}
482 
483 	return (*dev->dev_ops->dev_xstats_by_name_get)(dev, name, stat_id, value);
484 }
485 
486 int
487 rte_ml_dev_xstats_get(int16_t dev_id, enum rte_ml_dev_xstats_mode mode, int32_t model_id,
488 		      const uint16_t stat_ids[], uint64_t values[], uint16_t nb_ids)
489 {
490 	struct rte_ml_dev *dev;
491 
492 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
493 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
494 		return -EINVAL;
495 	}
496 
497 	dev = rte_ml_dev_pmd_get_dev(dev_id);
498 	if (*dev->dev_ops->dev_xstats_get == NULL)
499 		return -ENOTSUP;
500 
501 	if (stat_ids == NULL) {
502 		RTE_MLDEV_LOG(ERR, "Dev %d, stat_ids cannot be NULL", dev_id);
503 		return -EINVAL;
504 	}
505 
506 	if (values == NULL) {
507 		RTE_MLDEV_LOG(ERR, "Dev %d, values cannot be NULL", dev_id);
508 		return -EINVAL;
509 	}
510 
511 	return (*dev->dev_ops->dev_xstats_get)(dev, mode, model_id, stat_ids, values, nb_ids);
512 }
513 
514 int
515 rte_ml_dev_xstats_reset(int16_t dev_id, enum rte_ml_dev_xstats_mode mode, int32_t model_id,
516 			const uint16_t stat_ids[], uint16_t nb_ids)
517 {
518 	struct rte_ml_dev *dev;
519 
520 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
521 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
522 		return -EINVAL;
523 	}
524 
525 	dev = rte_ml_dev_pmd_get_dev(dev_id);
526 	if (*dev->dev_ops->dev_xstats_reset == NULL)
527 		return -ENOTSUP;
528 
529 	return (*dev->dev_ops->dev_xstats_reset)(dev, mode, model_id, stat_ids, nb_ids);
530 }
531 
532 int
533 rte_ml_dev_dump(int16_t dev_id, FILE *fd)
534 {
535 	struct rte_ml_dev *dev;
536 
537 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
538 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
539 		return -EINVAL;
540 	}
541 
542 	dev = rte_ml_dev_pmd_get_dev(dev_id);
543 	if (*dev->dev_ops->dev_dump == NULL)
544 		return -ENOTSUP;
545 
546 	if (fd == NULL) {
547 		RTE_MLDEV_LOG(ERR, "Dev %d, file descriptor cannot be NULL", dev_id);
548 		return -EINVAL;
549 	}
550 
551 	return (*dev->dev_ops->dev_dump)(dev, fd);
552 }
553 
554 int
555 rte_ml_dev_selftest(int16_t dev_id)
556 {
557 	struct rte_ml_dev *dev;
558 
559 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
560 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
561 		return -EINVAL;
562 	}
563 
564 	dev = rte_ml_dev_pmd_get_dev(dev_id);
565 	if (*dev->dev_ops->dev_selftest == NULL)
566 		return -ENOTSUP;
567 
568 	return (*dev->dev_ops->dev_selftest)(dev);
569 }
570 
571 int
572 rte_ml_model_load(int16_t dev_id, struct rte_ml_model_params *params, uint16_t *model_id)
573 {
574 	struct rte_ml_dev *dev;
575 
576 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
577 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
578 		return -EINVAL;
579 	}
580 
581 	dev = rte_ml_dev_pmd_get_dev(dev_id);
582 	if (*dev->dev_ops->model_load == NULL)
583 		return -ENOTSUP;
584 
585 	if (params == NULL) {
586 		RTE_MLDEV_LOG(ERR, "Dev %d, params cannot be NULL", dev_id);
587 		return -EINVAL;
588 	}
589 
590 	if (model_id == NULL) {
591 		RTE_MLDEV_LOG(ERR, "Dev %d, model_id cannot be NULL", dev_id);
592 		return -EINVAL;
593 	}
594 
595 	return (*dev->dev_ops->model_load)(dev, params, model_id);
596 }
597 
598 int
599 rte_ml_model_unload(int16_t dev_id, uint16_t model_id)
600 {
601 	struct rte_ml_dev *dev;
602 
603 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
604 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
605 		return -EINVAL;
606 	}
607 
608 	dev = rte_ml_dev_pmd_get_dev(dev_id);
609 	if (*dev->dev_ops->model_unload == NULL)
610 		return -ENOTSUP;
611 
612 	return (*dev->dev_ops->model_unload)(dev, model_id);
613 }
614 
615 int
616 rte_ml_model_start(int16_t dev_id, uint16_t model_id)
617 {
618 	struct rte_ml_dev *dev;
619 
620 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
621 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
622 		return -EINVAL;
623 	}
624 
625 	dev = rte_ml_dev_pmd_get_dev(dev_id);
626 	if (*dev->dev_ops->model_start == NULL)
627 		return -ENOTSUP;
628 
629 	return (*dev->dev_ops->model_start)(dev, model_id);
630 }
631 
632 int
633 rte_ml_model_stop(int16_t dev_id, uint16_t model_id)
634 {
635 	struct rte_ml_dev *dev;
636 
637 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
638 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
639 		return -EINVAL;
640 	}
641 
642 	dev = rte_ml_dev_pmd_get_dev(dev_id);
643 	if (*dev->dev_ops->model_stop == NULL)
644 		return -ENOTSUP;
645 
646 	return (*dev->dev_ops->model_stop)(dev, model_id);
647 }
648 
649 int
650 rte_ml_model_info_get(int16_t dev_id, uint16_t model_id, struct rte_ml_model_info *model_info)
651 {
652 	struct rte_ml_dev *dev;
653 
654 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
655 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
656 		return -EINVAL;
657 	}
658 
659 	dev = rte_ml_dev_pmd_get_dev(dev_id);
660 	if (*dev->dev_ops->model_info_get == NULL)
661 		return -ENOTSUP;
662 
663 	if (model_info == NULL) {
664 		RTE_MLDEV_LOG(ERR, "Dev %d, model_id %u, model_info cannot be NULL", dev_id,
665 			      model_id);
666 		return -EINVAL;
667 	}
668 
669 	return (*dev->dev_ops->model_info_get)(dev, model_id, model_info);
670 }
671 
672 int
673 rte_ml_model_params_update(int16_t dev_id, uint16_t model_id, void *buffer)
674 {
675 	struct rte_ml_dev *dev;
676 
677 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
678 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
679 		return -EINVAL;
680 	}
681 
682 	dev = rte_ml_dev_pmd_get_dev(dev_id);
683 	if (*dev->dev_ops->model_params_update == NULL)
684 		return -ENOTSUP;
685 
686 	if (buffer == NULL) {
687 		RTE_MLDEV_LOG(ERR, "Dev %d, buffer cannot be NULL", dev_id);
688 		return -EINVAL;
689 	}
690 
691 	return (*dev->dev_ops->model_params_update)(dev, model_id, buffer);
692 }
693 
694 int
695 rte_ml_io_quantize(int16_t dev_id, uint16_t model_id, struct rte_ml_buff_seg **dbuffer,
696 		   struct rte_ml_buff_seg **qbuffer)
697 {
698 	struct rte_ml_dev *dev;
699 
700 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
701 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
702 		return -EINVAL;
703 	}
704 
705 	dev = rte_ml_dev_pmd_get_dev(dev_id);
706 	if (*dev->dev_ops->io_quantize == NULL)
707 		return -ENOTSUP;
708 
709 	if (dbuffer == NULL) {
710 		RTE_MLDEV_LOG(ERR, "Dev %d, dbuffer cannot be NULL", dev_id);
711 		return -EINVAL;
712 	}
713 
714 	if (qbuffer == NULL) {
715 		RTE_MLDEV_LOG(ERR, "Dev %d, qbuffer cannot be NULL", dev_id);
716 		return -EINVAL;
717 	}
718 
719 	return (*dev->dev_ops->io_quantize)(dev, model_id, dbuffer, qbuffer);
720 }
721 
722 int
723 rte_ml_io_dequantize(int16_t dev_id, uint16_t model_id, struct rte_ml_buff_seg **qbuffer,
724 		     struct rte_ml_buff_seg **dbuffer)
725 {
726 	struct rte_ml_dev *dev;
727 
728 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
729 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
730 		return -EINVAL;
731 	}
732 
733 	dev = rte_ml_dev_pmd_get_dev(dev_id);
734 	if (*dev->dev_ops->io_dequantize == NULL)
735 		return -ENOTSUP;
736 
737 	if (qbuffer == NULL) {
738 		RTE_MLDEV_LOG(ERR, "Dev %d, qbuffer cannot be NULL", dev_id);
739 		return -EINVAL;
740 	}
741 
742 	if (dbuffer == NULL) {
743 		RTE_MLDEV_LOG(ERR, "Dev %d, dbuffer cannot be NULL", dev_id);
744 		return -EINVAL;
745 	}
746 
747 	return (*dev->dev_ops->io_dequantize)(dev, model_id, qbuffer, dbuffer);
748 }
749 
750 /** Initialise rte_ml_op mempool element */
751 static void
752 ml_op_init(struct rte_mempool *mempool, __rte_unused void *opaque_arg, void *_op_data,
753 	   __rte_unused unsigned int i)
754 {
755 	struct rte_ml_op *op = _op_data;
756 
757 	memset(_op_data, 0, mempool->elt_size);
758 	op->status = RTE_ML_OP_STATUS_NOT_PROCESSED;
759 	op->mempool = mempool;
760 }
761 
762 struct rte_mempool *
763 rte_ml_op_pool_create(const char *name, unsigned int nb_elts, unsigned int cache_size,
764 		      uint16_t user_size, int socket_id)
765 {
766 	struct rte_ml_op_pool_private *priv;
767 	struct rte_mempool *mp;
768 	unsigned int elt_size;
769 
770 	/* lookup mempool in case already allocated */
771 	mp = rte_mempool_lookup(name);
772 	elt_size = sizeof(struct rte_ml_op) + user_size;
773 
774 	if (mp != NULL) {
775 		priv = (struct rte_ml_op_pool_private *)rte_mempool_get_priv(mp);
776 		if (mp->elt_size != elt_size || mp->cache_size < cache_size || mp->size < nb_elts ||
777 		    priv->user_size < user_size) {
778 			mp = NULL;
779 			RTE_MLDEV_LOG(ERR,
780 				      "Mempool %s already exists but with incompatible parameters",
781 				      name);
782 			return NULL;
783 		}
784 		return mp;
785 	}
786 
787 	mp = rte_mempool_create(name, nb_elts, elt_size, cache_size,
788 				sizeof(struct rte_ml_op_pool_private), NULL, NULL, ml_op_init, NULL,
789 				socket_id, 0);
790 	if (mp == NULL) {
791 		RTE_MLDEV_LOG(ERR, "Failed to create mempool %s", name);
792 		return NULL;
793 	}
794 
795 	priv = (struct rte_ml_op_pool_private *)rte_mempool_get_priv(mp);
796 	priv->user_size = user_size;
797 
798 	return mp;
799 }
800 
/*
 * Release an op pool created by rte_ml_op_pool_create().
 * This is a thin wrapper that hands the pool straight to rte_mempool_free().
 */
void
rte_ml_op_pool_free(struct rte_mempool *mempool)
{
	struct rte_mempool *pool = mempool;

	rte_mempool_free(pool);
}
806 
807 uint16_t
808 rte_ml_enqueue_burst(int16_t dev_id, uint16_t qp_id, struct rte_ml_op **ops, uint16_t nb_ops)
809 {
810 	struct rte_ml_dev *dev;
811 
812 #ifdef RTE_LIBRTE_ML_DEV_DEBUG
813 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
814 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
815 		rte_errno = -EINVAL;
816 		return 0;
817 	}
818 
819 	dev = rte_ml_dev_pmd_get_dev(dev_id);
820 	if (*dev->enqueue_burst == NULL) {
821 		rte_errno = -ENOTSUP;
822 		return 0;
823 	}
824 
825 	if (ops == NULL) {
826 		RTE_MLDEV_LOG(ERR, "Dev %d, ops cannot be NULL", dev_id);
827 		rte_errno = -EINVAL;
828 		return 0;
829 	}
830 
831 	if (qp_id >= dev->data->nb_queue_pairs) {
832 		RTE_MLDEV_LOG(ERR, "Invalid qp_id %u", qp_id);
833 		rte_errno = -EINVAL;
834 		return 0;
835 	}
836 #else
837 	dev = rte_ml_dev_pmd_get_dev(dev_id);
838 #endif
839 
840 	return (*dev->enqueue_burst)(dev, qp_id, ops, nb_ops);
841 }
842 
843 uint16_t
844 rte_ml_dequeue_burst(int16_t dev_id, uint16_t qp_id, struct rte_ml_op **ops, uint16_t nb_ops)
845 {
846 	struct rte_ml_dev *dev;
847 
848 #ifdef RTE_LIBRTE_ML_DEV_DEBUG
849 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
850 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
851 		rte_errno = -EINVAL;
852 		return 0;
853 	}
854 
855 	dev = rte_ml_dev_pmd_get_dev(dev_id);
856 	if (*dev->dequeue_burst == NULL) {
857 		rte_errno = -ENOTSUP;
858 		return 0;
859 	}
860 
861 	if (ops == NULL) {
862 		RTE_MLDEV_LOG(ERR, "Dev %d, ops cannot be NULL", dev_id);
863 		rte_errno = -EINVAL;
864 		return 0;
865 	}
866 
867 	if (qp_id >= dev->data->nb_queue_pairs) {
868 		RTE_MLDEV_LOG(ERR, "Invalid qp_id %u", qp_id);
869 		rte_errno = -EINVAL;
870 		return 0;
871 	}
872 #else
873 	dev = rte_ml_dev_pmd_get_dev(dev_id);
874 #endif
875 
876 	return (*dev->dequeue_burst)(dev, qp_id, ops, nb_ops);
877 }
878 
879 int
880 rte_ml_op_error_get(int16_t dev_id, struct rte_ml_op *op, struct rte_ml_op_error *error)
881 {
882 	struct rte_ml_dev *dev;
883 
884 #ifdef RTE_LIBRTE_ML_DEV_DEBUG
885 	if (!rte_ml_dev_is_valid_dev(dev_id)) {
886 		RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d", dev_id);
887 		return -EINVAL;
888 	}
889 
890 	dev = rte_ml_dev_pmd_get_dev(dev_id);
891 	if (*dev->op_error_get == NULL)
892 		return -ENOTSUP;
893 
894 	if (op == NULL) {
895 		RTE_MLDEV_LOG(ERR, "Dev %d, op cannot be NULL", dev_id);
896 		return -EINVAL;
897 	}
898 
899 	if (error == NULL) {
900 		RTE_MLDEV_LOG(ERR, "Dev %d, error cannot be NULL", dev_id);
901 		return -EINVAL;
902 	}
903 #else
904 	dev = rte_ml_dev_pmd_get_dev(dev_id);
905 #endif
906 
907 	return (*dev->op_error_get)(dev, op, error);
908 }
909 
/* Register the rte_ml_dev_logtype log type with a default level of INFO. */
RTE_LOG_REGISTER_DEFAULT(rte_ml_dev_logtype, INFO);
911