1 /* SPDX-License-Identifier: BSD-3-Clause 2 * Copyright (c) 2022 Marvell. 3 */ 4 5 #include <rte_errno.h> 6 #include <rte_log.h> 7 #include <rte_mldev.h> 8 #include <rte_mldev_pmd.h> 9 10 #include <stdlib.h> 11 12 static struct rte_ml_dev_global ml_dev_globals = { 13 .devs = NULL, .data = NULL, .nb_devs = 0, .max_devs = RTE_MLDEV_DEFAULT_MAX}; 14 15 /* 16 * Private data structure of an operation pool. 17 * 18 * A structure that contains ml op_pool specific data that is 19 * appended after the mempool structure (in private data). 20 */ 21 struct rte_ml_op_pool_private { 22 uint16_t user_size; 23 /*< Size of private user data with each operation. */ 24 }; 25 26 struct rte_ml_dev * 27 rte_ml_dev_pmd_get_dev(int16_t dev_id) 28 { 29 return &ml_dev_globals.devs[dev_id]; 30 } 31 32 struct rte_ml_dev * 33 rte_ml_dev_pmd_get_named_dev(const char *name) 34 { 35 struct rte_ml_dev *dev; 36 int16_t dev_id; 37 38 if (name == NULL) 39 return NULL; 40 41 for (dev_id = 0; dev_id < ml_dev_globals.max_devs; dev_id++) { 42 dev = rte_ml_dev_pmd_get_dev(dev_id); 43 if ((dev->attached == ML_DEV_ATTACHED) && (strcmp(dev->data->name, name) == 0)) 44 return dev; 45 } 46 47 return NULL; 48 } 49 50 struct rte_ml_dev * 51 rte_ml_dev_pmd_allocate(const char *name, uint8_t socket_id) 52 { 53 char mz_name[RTE_MEMZONE_NAMESIZE]; 54 const struct rte_memzone *mz; 55 struct rte_ml_dev *dev; 56 int16_t dev_id; 57 58 /* implicit initialization of library before adding first device */ 59 if (ml_dev_globals.devs == NULL) { 60 if (rte_ml_dev_init(RTE_MLDEV_DEFAULT_MAX) != 0) 61 return NULL; 62 } 63 64 if (rte_ml_dev_pmd_get_named_dev(name) != NULL) { 65 RTE_MLDEV_LOG(ERR, "ML device with name %s already allocated!", name); 66 return NULL; 67 } 68 69 /* Get a free device ID */ 70 for (dev_id = 0; dev_id < ml_dev_globals.max_devs; dev_id++) { 71 dev = rte_ml_dev_pmd_get_dev(dev_id); 72 if (dev->attached == ML_DEV_DETACHED) 73 break; 74 } 75 76 if (dev_id == ml_dev_globals.max_devs) { 77 RTE_MLDEV_LOG(ERR, "Reached maximum number of ML devices"); 78 return NULL; 79 } 80 81 if (dev->data == NULL) { 82 /* Reserve memzone name */ 83 sprintf(mz_name, "rte_ml_dev_data_%d", dev_id); 84 if (rte_eal_process_type() == RTE_PROC_PRIMARY) { 85 mz = rte_memzone_reserve(mz_name, sizeof(struct rte_ml_dev_data), socket_id, 86 0); 87 RTE_MLDEV_LOG(DEBUG, "PRIMARY: reserved memzone for %s (%p)", mz_name, mz); 88 } else { 89 mz = rte_memzone_lookup(mz_name); 90 RTE_MLDEV_LOG(DEBUG, "SECONDARY: looked up memzone for %s (%p)", mz_name, 91 mz); 92 } 93 94 if (mz == NULL) 95 return NULL; 96 97 ml_dev_globals.data[dev_id] = mz->addr; 98 if (rte_eal_process_type() == RTE_PROC_PRIMARY) 99 memset(ml_dev_globals.data[dev_id], 0, sizeof(struct rte_ml_dev_data)); 100 101 dev->data = ml_dev_globals.data[dev_id]; 102 if (rte_eal_process_type() == RTE_PROC_PRIMARY) { 103 strlcpy(dev->data->name, name, RTE_ML_STR_MAX); 104 dev->data->dev_id = dev_id; 105 dev->data->socket_id = socket_id; 106 dev->data->dev_started = 0; 107 RTE_MLDEV_LOG(DEBUG, "PRIMARY: init mldev data"); 108 } 109 110 RTE_MLDEV_LOG(DEBUG, "Data for %s: dev_id %d, socket %u", dev->data->name, 111 dev->data->dev_id, dev->data->socket_id); 112 113 dev->attached = ML_DEV_ATTACHED; 114 ml_dev_globals.nb_devs++; 115 } 116 117 dev->enqueue_burst = NULL; 118 dev->dequeue_burst = NULL; 119 120 return dev; 121 } 122 123 int 124 rte_ml_dev_pmd_release(struct rte_ml_dev *dev) 125 { 126 char mz_name[RTE_MEMZONE_NAMESIZE]; 127 const struct rte_memzone *mz; 128 int16_t dev_id; 129 int ret = 0; 130 131 if (dev == NULL) 132 return -EINVAL; 133 134 dev_id = dev->data->dev_id; 135 136 /* Memzone lookup */ 137 sprintf(mz_name, "rte_ml_dev_data_%d", dev_id); 138 mz = rte_memzone_lookup(mz_name); 139 if (mz == NULL) 140 return -ENOMEM; 141 142 RTE_ASSERT(ml_dev_globals.data[dev_id] == mz->addr); 143 ml_dev_globals.data[dev_id] = NULL; 144 145 if (rte_eal_process_type() == RTE_PROC_PRIMARY) { 146 RTE_MLDEV_LOG(DEBUG, "PRIMARY: free memzone of %s (%p)", mz_name, mz); 147 ret = rte_memzone_free(mz); 148 } else { 149 RTE_MLDEV_LOG(DEBUG, "SECONDARY: don't free memzone of %s (%p)", mz_name, mz); 150 } 151 152 dev->attached = ML_DEV_DETACHED; 153 ml_dev_globals.nb_devs--; 154 155 return ret; 156 } 157 158 int 159 rte_ml_dev_init(size_t dev_max) 160 { 161 if (dev_max == 0 || dev_max > INT16_MAX) { 162 RTE_MLDEV_LOG(ERR, "Invalid dev_max = %zu (> %d)\n", dev_max, INT16_MAX); 163 rte_errno = EINVAL; 164 return -rte_errno; 165 } 166 167 /* No lock, it must be called before or during first probing. */ 168 if (ml_dev_globals.devs != NULL) { 169 RTE_MLDEV_LOG(ERR, "Device array already initialized"); 170 rte_errno = EBUSY; 171 return -rte_errno; 172 } 173 174 ml_dev_globals.devs = calloc(dev_max, sizeof(struct rte_ml_dev)); 175 if (ml_dev_globals.devs == NULL) { 176 RTE_MLDEV_LOG(ERR, "Cannot initialize MLDEV library"); 177 rte_errno = ENOMEM; 178 return -rte_errno; 179 } 180 181 ml_dev_globals.data = calloc(dev_max, sizeof(struct rte_ml_dev_data *)); 182 if (ml_dev_globals.data == NULL) { 183 RTE_MLDEV_LOG(ERR, "Cannot initialize MLDEV library"); 184 rte_errno = ENOMEM; 185 return -rte_errno; 186 } 187 188 ml_dev_globals.max_devs = dev_max; 189 ml_dev_globals.devs = ml_dev_globals.devs; 190 191 return 0; 192 } 193 194 uint16_t 195 rte_ml_dev_count(void) 196 { 197 return ml_dev_globals.nb_devs; 198 } 199 200 int 201 rte_ml_dev_is_valid_dev(int16_t dev_id) 202 { 203 struct rte_ml_dev *dev = NULL; 204 205 if (dev_id >= ml_dev_globals.max_devs || ml_dev_globals.devs[dev_id].data == NULL) 206 return 0; 207 208 dev = rte_ml_dev_pmd_get_dev(dev_id); 209 if (dev->attached != ML_DEV_ATTACHED) 210 return 0; 211 else 212 return 1; 213 } 214 215 int 216 rte_ml_dev_socket_id(int16_t dev_id) 217 { 218 struct rte_ml_dev *dev; 219 220 if (!rte_ml_dev_is_valid_dev(dev_id)) { 221 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 222 return -EINVAL; 223 } 224 225 dev = rte_ml_dev_pmd_get_dev(dev_id); 226 227 return dev->data->socket_id; 228 } 229 230 int 231 rte_ml_dev_info_get(int16_t dev_id, struct rte_ml_dev_info *dev_info) 232 { 233 struct rte_ml_dev *dev; 234 235 if (!rte_ml_dev_is_valid_dev(dev_id)) { 236 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 237 return -EINVAL; 238 } 239 240 dev = rte_ml_dev_pmd_get_dev(dev_id); 241 if (*dev->dev_ops->dev_info_get == NULL) 242 return -ENOTSUP; 243 244 if (dev_info == NULL) { 245 RTE_MLDEV_LOG(ERR, "Dev %d, dev_info cannot be NULL\n", dev_id); 246 return -EINVAL; 247 } 248 memset(dev_info, 0, sizeof(struct rte_ml_dev_info)); 249 250 return (*dev->dev_ops->dev_info_get)(dev, dev_info); 251 } 252 253 int 254 rte_ml_dev_configure(int16_t dev_id, const struct rte_ml_dev_config *config) 255 { 256 struct rte_ml_dev_info dev_info; 257 struct rte_ml_dev *dev; 258 int ret; 259 260 if (!rte_ml_dev_is_valid_dev(dev_id)) { 261 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 262 return -EINVAL; 263 } 264 265 dev = rte_ml_dev_pmd_get_dev(dev_id); 266 if (*dev->dev_ops->dev_configure == NULL) 267 return -ENOTSUP; 268 269 if (dev->data->dev_started) { 270 RTE_MLDEV_LOG(ERR, "Device %d must be stopped to allow configuration", dev_id); 271 return -EBUSY; 272 } 273 274 if (config == NULL) { 275 RTE_MLDEV_LOG(ERR, "Dev %d, config cannot be NULL\n", dev_id); 276 return -EINVAL; 277 } 278 279 ret = rte_ml_dev_info_get(dev_id, &dev_info); 280 if (ret < 0) 281 return ret; 282 283 if (config->nb_queue_pairs > dev_info.max_queue_pairs) { 284 RTE_MLDEV_LOG(ERR, "Device %d num of queues %u > %u\n", dev_id, 285 config->nb_queue_pairs, dev_info.max_queue_pairs); 286 return -EINVAL; 287 } 288 289 return (*dev->dev_ops->dev_configure)(dev, config); 290 } 291 292 int 293 rte_ml_dev_close(int16_t dev_id) 294 { 295 struct rte_ml_dev *dev; 296 297 if (!rte_ml_dev_is_valid_dev(dev_id)) { 298 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 299 return -EINVAL; 300 } 301 302 dev = rte_ml_dev_pmd_get_dev(dev_id); 303 if (*dev->dev_ops->dev_close == NULL) 304 return -ENOTSUP; 305 306 /* Device must be stopped before it can be closed */ 307 if (dev->data->dev_started == 1) { 308 RTE_MLDEV_LOG(ERR, "Device %d must be stopped before closing", dev_id); 309 return -EBUSY; 310 } 311 312 return (*dev->dev_ops->dev_close)(dev); 313 } 314 315 int 316 rte_ml_dev_start(int16_t dev_id) 317 { 318 struct rte_ml_dev *dev; 319 int ret; 320 321 if (!rte_ml_dev_is_valid_dev(dev_id)) { 322 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 323 return -EINVAL; 324 } 325 326 dev = rte_ml_dev_pmd_get_dev(dev_id); 327 if (*dev->dev_ops->dev_start == NULL) 328 return -ENOTSUP; 329 330 if (dev->data->dev_started != 0) { 331 RTE_MLDEV_LOG(ERR, "Device %d is already started", dev_id); 332 return -EBUSY; 333 } 334 335 ret = (*dev->dev_ops->dev_start)(dev); 336 if (ret == 0) 337 dev->data->dev_started = 1; 338 339 return ret; 340 } 341 342 int 343 rte_ml_dev_stop(int16_t dev_id) 344 { 345 struct rte_ml_dev *dev; 346 int ret; 347 348 if (!rte_ml_dev_is_valid_dev(dev_id)) { 349 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 350 return -EINVAL; 351 } 352 353 dev = rte_ml_dev_pmd_get_dev(dev_id); 354 if (*dev->dev_ops->dev_stop == NULL) 355 return -ENOTSUP; 356 357 if (dev->data->dev_started == 0) { 358 RTE_MLDEV_LOG(ERR, "Device %d is not started", dev_id); 359 return -EBUSY; 360 } 361 362 ret = (*dev->dev_ops->dev_stop)(dev); 363 if (ret == 0) 364 dev->data->dev_started = 0; 365 366 return ret; 367 } 368 369 int 370 rte_ml_dev_queue_pair_setup(int16_t dev_id, uint16_t queue_pair_id, 371 const struct rte_ml_dev_qp_conf *qp_conf, int socket_id) 372 { 373 struct rte_ml_dev *dev; 374 375 if (!rte_ml_dev_is_valid_dev(dev_id)) { 376 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 377 return -EINVAL; 378 } 379 380 dev = rte_ml_dev_pmd_get_dev(dev_id); 381 if (*dev->dev_ops->dev_queue_pair_setup == NULL) 382 return -ENOTSUP; 383 384 if (queue_pair_id >= dev->data->nb_queue_pairs) { 385 RTE_MLDEV_LOG(ERR, "Invalid queue_pair_id = %d", queue_pair_id); 386 return -EINVAL; 387 } 388 389 if (qp_conf == NULL) { 390 RTE_MLDEV_LOG(ERR, "Dev %d, qp_conf cannot be NULL\n", dev_id); 391 return -EINVAL; 392 } 393 394 if (dev->data->dev_started) { 395 RTE_MLDEV_LOG(ERR, "Device %d must be stopped to allow configuration", dev_id); 396 return -EBUSY; 397 } 398 399 return (*dev->dev_ops->dev_queue_pair_setup)(dev, queue_pair_id, qp_conf, socket_id); 400 } 401 402 int 403 rte_ml_dev_stats_get(int16_t dev_id, struct rte_ml_dev_stats *stats) 404 { 405 struct rte_ml_dev *dev; 406 407 if (!rte_ml_dev_is_valid_dev(dev_id)) { 408 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 409 return -EINVAL; 410 } 411 412 dev = rte_ml_dev_pmd_get_dev(dev_id); 413 if (*dev->dev_ops->dev_stats_get == NULL) 414 return -ENOTSUP; 415 416 if (stats == NULL) { 417 RTE_MLDEV_LOG(ERR, "Dev %d, stats cannot be NULL\n", dev_id); 418 return -EINVAL; 419 } 420 memset(stats, 0, sizeof(struct rte_ml_dev_stats)); 421 422 return (*dev->dev_ops->dev_stats_get)(dev, stats); 423 } 424 425 void 426 rte_ml_dev_stats_reset(int16_t dev_id) 427 { 428 struct rte_ml_dev *dev; 429 430 if (!rte_ml_dev_is_valid_dev(dev_id)) { 431 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 432 return; 433 } 434 435 dev = rte_ml_dev_pmd_get_dev(dev_id); 436 if (*dev->dev_ops->dev_stats_reset == NULL) 437 return; 438 439 (*dev->dev_ops->dev_stats_reset)(dev); 440 } 441 442 int 443 rte_ml_dev_xstats_names_get(int16_t dev_id, struct rte_ml_dev_xstats_map *xstats_map, uint32_t size) 444 { 445 struct rte_ml_dev *dev; 446 447 if (!rte_ml_dev_is_valid_dev(dev_id)) { 448 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 449 return -EINVAL; 450 } 451 452 dev = rte_ml_dev_pmd_get_dev(dev_id); 453 if (*dev->dev_ops->dev_xstats_names_get == NULL) 454 return -ENOTSUP; 455 456 return (*dev->dev_ops->dev_xstats_names_get)(dev, xstats_map, size); 457 } 458 459 int 460 rte_ml_dev_xstats_by_name_get(int16_t dev_id, const char *name, uint16_t *stat_id, uint64_t *value) 461 { 462 struct rte_ml_dev *dev; 463 464 if (!rte_ml_dev_is_valid_dev(dev_id)) { 465 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 466 return -EINVAL; 467 } 468 469 dev = rte_ml_dev_pmd_get_dev(dev_id); 470 if (*dev->dev_ops->dev_xstats_by_name_get == NULL) 471 return -ENOTSUP; 472 473 if (name == NULL) { 474 RTE_MLDEV_LOG(ERR, "Dev %d, name cannot be NULL\n", dev_id); 475 return -EINVAL; 476 } 477 478 if (value == NULL) { 479 RTE_MLDEV_LOG(ERR, "Dev %d, value cannot be NULL\n", dev_id); 480 return -EINVAL; 481 } 482 483 return (*dev->dev_ops->dev_xstats_by_name_get)(dev, name, stat_id, value); 484 } 485 486 int 487 rte_ml_dev_xstats_get(int16_t dev_id, const uint16_t *stat_ids, uint64_t *values, uint16_t nb_ids) 488 { 489 struct rte_ml_dev *dev; 490 491 if (!rte_ml_dev_is_valid_dev(dev_id)) { 492 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 493 return -EINVAL; 494 } 495 496 dev = rte_ml_dev_pmd_get_dev(dev_id); 497 if (*dev->dev_ops->dev_xstats_get == NULL) 498 return -ENOTSUP; 499 500 if (stat_ids == NULL) { 501 RTE_MLDEV_LOG(ERR, "Dev %d, stat_ids cannot be NULL\n", dev_id); 502 return -EINVAL; 503 } 504 505 if (values == NULL) { 506 RTE_MLDEV_LOG(ERR, "Dev %d, values cannot be NULL\n", dev_id); 507 return -EINVAL; 508 } 509 510 return (*dev->dev_ops->dev_xstats_get)(dev, stat_ids, values, nb_ids); 511 } 512 513 int 514 rte_ml_dev_xstats_reset(int16_t dev_id, const uint16_t *stat_ids, uint16_t nb_ids) 515 { 516 struct rte_ml_dev *dev; 517 518 if (!rte_ml_dev_is_valid_dev(dev_id)) { 519 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 520 return -EINVAL; 521 } 522 523 dev = rte_ml_dev_pmd_get_dev(dev_id); 524 if (*dev->dev_ops->dev_xstats_reset == NULL) 525 return -ENOTSUP; 526 527 return (*dev->dev_ops->dev_xstats_reset)(dev, stat_ids, nb_ids); 528 } 529 530 int 531 rte_ml_dev_dump(int16_t dev_id, FILE *fd) 532 { 533 struct rte_ml_dev *dev; 534 535 if (!rte_ml_dev_is_valid_dev(dev_id)) { 536 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 537 return -EINVAL; 538 } 539 540 dev = rte_ml_dev_pmd_get_dev(dev_id); 541 if (*dev->dev_ops->dev_dump == NULL) 542 return -ENOTSUP; 543 544 if (fd == NULL) { 545 RTE_MLDEV_LOG(ERR, "Dev %d, file descriptor cannot be NULL\n", dev_id); 546 return -EINVAL; 547 } 548 549 return (*dev->dev_ops->dev_dump)(dev, fd); 550 } 551 552 int 553 rte_ml_dev_selftest(int16_t dev_id) 554 { 555 struct rte_ml_dev *dev; 556 557 if (!rte_ml_dev_is_valid_dev(dev_id)) { 558 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 559 return -EINVAL; 560 } 561 562 dev = rte_ml_dev_pmd_get_dev(dev_id); 563 if (*dev->dev_ops->dev_selftest == NULL) 564 return -ENOTSUP; 565 566 return (*dev->dev_ops->dev_selftest)(dev); 567 } 568 569 int 570 rte_ml_model_load(int16_t dev_id, struct rte_ml_model_params *params, uint16_t *model_id) 571 { 572 struct rte_ml_dev *dev; 573 574 if (!rte_ml_dev_is_valid_dev(dev_id)) { 575 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 576 return -EINVAL; 577 } 578 579 dev = rte_ml_dev_pmd_get_dev(dev_id); 580 if (*dev->dev_ops->model_load == NULL) 581 return -ENOTSUP; 582 583 if (params == NULL) { 584 RTE_MLDEV_LOG(ERR, "Dev %d, params cannot be NULL\n", dev_id); 585 return -EINVAL; 586 } 587 588 if (model_id == NULL) { 589 RTE_MLDEV_LOG(ERR, "Dev %d, model_id cannot be NULL\n", dev_id); 590 return -EINVAL; 591 } 592 593 return (*dev->dev_ops->model_load)(dev, params, model_id); 594 } 595 596 int 597 rte_ml_model_unload(int16_t dev_id, uint16_t model_id) 598 { 599 struct rte_ml_dev *dev; 600 601 if (!rte_ml_dev_is_valid_dev(dev_id)) { 602 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 603 return -EINVAL; 604 } 605 606 dev = rte_ml_dev_pmd_get_dev(dev_id); 607 if (*dev->dev_ops->model_unload == NULL) 608 return -ENOTSUP; 609 610 return (*dev->dev_ops->model_unload)(dev, model_id); 611 } 612 613 int 614 rte_ml_model_start(int16_t dev_id, uint16_t model_id) 615 { 616 struct rte_ml_dev *dev; 617 618 if (!rte_ml_dev_is_valid_dev(dev_id)) { 619 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 620 return -EINVAL; 621 } 622 623 dev = rte_ml_dev_pmd_get_dev(dev_id); 624 if (*dev->dev_ops->model_start == NULL) 625 return -ENOTSUP; 626 627 return (*dev->dev_ops->model_start)(dev, model_id); 628 } 629 630 int 631 rte_ml_model_stop(int16_t dev_id, uint16_t model_id) 632 { 633 struct rte_ml_dev *dev; 634 635 if (!rte_ml_dev_is_valid_dev(dev_id)) { 636 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 637 return -EINVAL; 638 } 639 640 dev = rte_ml_dev_pmd_get_dev(dev_id); 641 if (*dev->dev_ops->model_stop == NULL) 642 return -ENOTSUP; 643 644 return (*dev->dev_ops->model_stop)(dev, model_id); 645 } 646 647 int 648 rte_ml_model_info_get(int16_t dev_id, uint16_t model_id, struct rte_ml_model_info *model_info) 649 { 650 struct rte_ml_dev *dev; 651 652 if (!rte_ml_dev_is_valid_dev(dev_id)) { 653 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 654 return -EINVAL; 655 } 656 657 dev = rte_ml_dev_pmd_get_dev(dev_id); 658 if (*dev->dev_ops->model_info_get == NULL) 659 return -ENOTSUP; 660 661 if (model_info == NULL) { 662 RTE_MLDEV_LOG(ERR, "Dev %d, model_id %u, model_info cannot be NULL\n", dev_id, 663 model_id); 664 return -EINVAL; 665 } 666 667 return (*dev->dev_ops->model_info_get)(dev, model_id, model_info); 668 } 669 670 int 671 rte_ml_model_params_update(int16_t dev_id, uint16_t model_id, void *buffer) 672 { 673 struct rte_ml_dev *dev; 674 675 if (!rte_ml_dev_is_valid_dev(dev_id)) { 676 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 677 return -EINVAL; 678 } 679 680 dev = rte_ml_dev_pmd_get_dev(dev_id); 681 if (*dev->dev_ops->model_params_update == NULL) 682 return -ENOTSUP; 683 684 if (buffer == NULL) { 685 RTE_MLDEV_LOG(ERR, "Dev %d, buffer cannot be NULL\n", dev_id); 686 return -EINVAL; 687 } 688 689 return (*dev->dev_ops->model_params_update)(dev, model_id, buffer); 690 } 691 692 int 693 rte_ml_io_input_size_get(int16_t dev_id, uint16_t model_id, uint32_t nb_batches, 694 uint64_t *input_qsize, uint64_t *input_dsize) 695 { 696 struct rte_ml_dev *dev; 697 698 if (!rte_ml_dev_is_valid_dev(dev_id)) { 699 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 700 return -EINVAL; 701 } 702 703 dev = rte_ml_dev_pmd_get_dev(dev_id); 704 if (*dev->dev_ops->io_input_size_get == NULL) 705 return -ENOTSUP; 706 707 return (*dev->dev_ops->io_input_size_get)(dev, model_id, nb_batches, input_qsize, 708 input_dsize); 709 } 710 711 int 712 rte_ml_io_output_size_get(int16_t dev_id, uint16_t model_id, uint32_t nb_batches, 713 uint64_t *output_qsize, uint64_t *output_dsize) 714 { 715 struct rte_ml_dev *dev; 716 717 if (!rte_ml_dev_is_valid_dev(dev_id)) { 718 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 719 return -EINVAL; 720 } 721 722 dev = rte_ml_dev_pmd_get_dev(dev_id); 723 if (*dev->dev_ops->io_output_size_get == NULL) 724 return -ENOTSUP; 725 726 return (*dev->dev_ops->io_output_size_get)(dev, model_id, nb_batches, output_qsize, 727 output_dsize); 728 } 729 730 int 731 rte_ml_io_quantize(int16_t dev_id, uint16_t model_id, uint16_t nb_batches, void *dbuffer, 732 void *qbuffer) 733 { 734 struct rte_ml_dev *dev; 735 736 if (!rte_ml_dev_is_valid_dev(dev_id)) { 737 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 738 return -EINVAL; 739 } 740 741 dev = rte_ml_dev_pmd_get_dev(dev_id); 742 if (*dev->dev_ops->io_quantize == NULL) 743 return -ENOTSUP; 744 745 if (dbuffer == NULL) { 746 RTE_MLDEV_LOG(ERR, "Dev %d, dbuffer cannot be NULL\n", dev_id); 747 return -EINVAL; 748 } 749 750 if (qbuffer == NULL) { 751 RTE_MLDEV_LOG(ERR, "Dev %d, qbuffer cannot be NULL\n", dev_id); 752 return -EINVAL; 753 } 754 755 return (*dev->dev_ops->io_quantize)(dev, model_id, nb_batches, dbuffer, qbuffer); 756 } 757 758 int 759 rte_ml_io_dequantize(int16_t dev_id, uint16_t model_id, uint16_t nb_batches, void *qbuffer, 760 void *dbuffer) 761 { 762 struct rte_ml_dev *dev; 763 764 if (!rte_ml_dev_is_valid_dev(dev_id)) { 765 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 766 return -EINVAL; 767 } 768 769 dev = rte_ml_dev_pmd_get_dev(dev_id); 770 if (*dev->dev_ops->io_dequantize == NULL) 771 return -ENOTSUP; 772 773 if (qbuffer == NULL) { 774 RTE_MLDEV_LOG(ERR, "Dev %d, qbuffer cannot be NULL\n", dev_id); 775 return -EINVAL; 776 } 777 778 if (dbuffer == NULL) { 779 RTE_MLDEV_LOG(ERR, "Dev %d, dbuffer cannot be NULL\n", dev_id); 780 return -EINVAL; 781 } 782 783 return (*dev->dev_ops->io_dequantize)(dev, model_id, nb_batches, qbuffer, dbuffer); 784 } 785 786 /** Initialise rte_ml_op mempool element */ 787 static void 788 ml_op_init(struct rte_mempool *mempool, __rte_unused void *opaque_arg, void *_op_data, 789 __rte_unused unsigned int i) 790 { 791 struct rte_ml_op *op = _op_data; 792 793 memset(_op_data, 0, mempool->elt_size); 794 op->status = RTE_ML_OP_STATUS_NOT_PROCESSED; 795 op->mempool = mempool; 796 } 797 798 struct rte_mempool * 799 rte_ml_op_pool_create(const char *name, unsigned int nb_elts, unsigned int cache_size, 800 uint16_t user_size, int socket_id) 801 { 802 struct rte_ml_op_pool_private *priv; 803 struct rte_mempool *mp; 804 unsigned int elt_size; 805 806 /* lookup mempool in case already allocated */ 807 mp = rte_mempool_lookup(name); 808 elt_size = sizeof(struct rte_ml_op) + user_size; 809 810 if (mp != NULL) { 811 priv = (struct rte_ml_op_pool_private *)rte_mempool_get_priv(mp); 812 if (mp->elt_size != elt_size || mp->cache_size < cache_size || mp->size < nb_elts || 813 priv->user_size < user_size) { 814 mp = NULL; 815 RTE_MLDEV_LOG(ERR, 816 "Mempool %s already exists but with incompatible parameters", 817 name); 818 return NULL; 819 } 820 return mp; 821 } 822 823 mp = rte_mempool_create(name, nb_elts, elt_size, cache_size, 824 sizeof(struct rte_ml_op_pool_private), NULL, NULL, ml_op_init, NULL, 825 socket_id, 0); 826 if (mp == NULL) { 827 RTE_MLDEV_LOG(ERR, "Failed to create mempool %s", name); 828 return NULL; 829 } 830 831 priv = (struct rte_ml_op_pool_private *)rte_mempool_get_priv(mp); 832 priv->user_size = user_size; 833 834 return mp; 835 } 836 837 void 838 rte_ml_op_pool_free(struct rte_mempool *mempool) 839 { 840 if (mempool != NULL) 841 rte_mempool_free(mempool); 842 } 843 844 uint16_t 845 rte_ml_enqueue_burst(int16_t dev_id, uint16_t qp_id, struct rte_ml_op **ops, uint16_t nb_ops) 846 { 847 struct rte_ml_dev *dev; 848 849 #ifdef RTE_LIBRTE_ML_DEV_DEBUG 850 if (!rte_ml_dev_is_valid_dev(dev_id)) { 851 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 852 rte_errno = -EINVAL; 853 return 0; 854 } 855 856 dev = rte_ml_dev_pmd_get_dev(dev_id); 857 if (*dev->enqueue_burst == NULL) { 858 rte_errno = -ENOTSUP; 859 return 0; 860 } 861 862 if (ops == NULL) { 863 RTE_MLDEV_LOG(ERR, "Dev %d, ops cannot be NULL\n", dev_id); 864 rte_errno = -EINVAL; 865 return 0; 866 } 867 868 if (qp_id >= dev->data->nb_queue_pairs) { 869 RTE_MLDEV_LOG(ERR, "Invalid qp_id %u\n", qp_id); 870 rte_errno = -EINVAL; 871 return 0; 872 } 873 #else 874 dev = rte_ml_dev_pmd_get_dev(dev_id); 875 #endif 876 877 return (*dev->enqueue_burst)(dev, qp_id, ops, nb_ops); 878 } 879 880 uint16_t 881 rte_ml_dequeue_burst(int16_t dev_id, uint16_t qp_id, struct rte_ml_op **ops, uint16_t nb_ops) 882 { 883 struct rte_ml_dev *dev; 884 885 #ifdef RTE_LIBRTE_ML_DEV_DEBUG 886 if (!rte_ml_dev_is_valid_dev(dev_id)) { 887 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 888 rte_errno = -EINVAL; 889 return 0; 890 } 891 892 dev = rte_ml_dev_pmd_get_dev(dev_id); 893 if (*dev->dequeue_burst == NULL) { 894 rte_errno = -ENOTSUP; 895 return 0; 896 } 897 898 if (ops == NULL) { 899 RTE_MLDEV_LOG(ERR, "Dev %d, ops cannot be NULL\n", dev_id); 900 rte_errno = -EINVAL; 901 return 0; 902 } 903 904 if (qp_id >= dev->data->nb_queue_pairs) { 905 RTE_MLDEV_LOG(ERR, "Invalid qp_id %u\n", qp_id); 906 rte_errno = -EINVAL; 907 return 0; 908 } 909 #else 910 dev = rte_ml_dev_pmd_get_dev(dev_id); 911 #endif 912 913 return (*dev->dequeue_burst)(dev, qp_id, ops, nb_ops); 914 } 915 916 int 917 rte_ml_op_error_get(int16_t dev_id, struct rte_ml_op *op, struct rte_ml_op_error *error) 918 { 919 struct rte_ml_dev *dev; 920 921 #ifdef RTE_LIBRTE_ML_DEV_DEBUG 922 if (!rte_ml_dev_is_valid_dev(dev_id)) { 923 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 924 return -EINVAL; 925 } 926 927 dev = rte_ml_dev_pmd_get_dev(dev_id); 928 if (*dev->op_error_get == NULL) 929 return -ENOTSUP; 930 931 if (op == NULL) { 932 RTE_MLDEV_LOG(ERR, "Dev %d, op cannot be NULL\n", dev_id); 933 return -EINVAL; 934 } 935 936 if (error == NULL) { 937 RTE_MLDEV_LOG(ERR, "Dev %d, error cannot be NULL\n", dev_id); 938 return -EINVAL; 939 } 940 #else 941 dev = rte_ml_dev_pmd_get_dev(dev_id); 942 #endif 943 944 return (*dev->op_error_get)(dev, op, error); 945 } 946 947 RTE_LOG_REGISTER_DEFAULT(rte_ml_dev_logtype, INFO); 948