1 /* SPDX-License-Identifier: BSD-3-Clause 2 * Copyright (c) 2022 Marvell. 3 */ 4 5 #include <rte_errno.h> 6 #include <rte_log.h> 7 #include <rte_mldev.h> 8 #include <rte_mldev_pmd.h> 9 10 #include <stdlib.h> 11 12 static struct rte_ml_dev_global ml_dev_globals = { 13 .devs = NULL, .data = NULL, .nb_devs = 0, .max_devs = RTE_MLDEV_DEFAULT_MAX}; 14 15 /* 16 * Private data structure of an operation pool. 17 * 18 * A structure that contains ml op_pool specific data that is 19 * appended after the mempool structure (in private data). 20 */ 21 struct rte_ml_op_pool_private { 22 uint16_t user_size; 23 /*< Size of private user data with each operation. */ 24 }; 25 26 struct rte_ml_dev * 27 rte_ml_dev_pmd_get_dev(int16_t dev_id) 28 { 29 return &ml_dev_globals.devs[dev_id]; 30 } 31 32 struct rte_ml_dev * 33 rte_ml_dev_pmd_get_named_dev(const char *name) 34 { 35 struct rte_ml_dev *dev; 36 int16_t dev_id; 37 38 if (name == NULL) 39 return NULL; 40 41 for (dev_id = 0; dev_id < ml_dev_globals.max_devs; dev_id++) { 42 dev = rte_ml_dev_pmd_get_dev(dev_id); 43 if ((dev->attached == ML_DEV_ATTACHED) && (strcmp(dev->data->name, name) == 0)) 44 return dev; 45 } 46 47 return NULL; 48 } 49 50 struct rte_ml_dev * 51 rte_ml_dev_pmd_allocate(const char *name, uint8_t socket_id) 52 { 53 char mz_name[RTE_MEMZONE_NAMESIZE]; 54 const struct rte_memzone *mz; 55 struct rte_ml_dev *dev; 56 int16_t dev_id; 57 58 /* implicit initialization of library before adding first device */ 59 if (ml_dev_globals.devs == NULL) { 60 if (rte_ml_dev_init(RTE_MLDEV_DEFAULT_MAX) != 0) 61 return NULL; 62 } 63 64 if (rte_ml_dev_pmd_get_named_dev(name) != NULL) { 65 RTE_MLDEV_LOG(ERR, "ML device with name %s already allocated!", name); 66 return NULL; 67 } 68 69 /* Get a free device ID */ 70 for (dev_id = 0; dev_id < ml_dev_globals.max_devs; dev_id++) { 71 dev = rte_ml_dev_pmd_get_dev(dev_id); 72 if (dev->attached == ML_DEV_DETACHED) 73 break; 74 } 75 76 if (dev_id == ml_dev_globals.max_devs) { 77 RTE_MLDEV_LOG(ERR, "Reached maximum number of ML devices"); 78 return NULL; 79 } 80 81 if (dev->data == NULL) { 82 /* Reserve memzone name */ 83 sprintf(mz_name, "rte_ml_dev_data_%d", dev_id); 84 if (rte_eal_process_type() == RTE_PROC_PRIMARY) { 85 mz = rte_memzone_reserve(mz_name, sizeof(struct rte_ml_dev_data), socket_id, 86 0); 87 RTE_MLDEV_LOG(DEBUG, "PRIMARY: reserved memzone for %s (%p)", mz_name, mz); 88 } else { 89 mz = rte_memzone_lookup(mz_name); 90 RTE_MLDEV_LOG(DEBUG, "SECONDARY: looked up memzone for %s (%p)", mz_name, 91 mz); 92 } 93 94 if (mz == NULL) 95 return NULL; 96 97 ml_dev_globals.data[dev_id] = mz->addr; 98 if (rte_eal_process_type() == RTE_PROC_PRIMARY) 99 memset(ml_dev_globals.data[dev_id], 0, sizeof(struct rte_ml_dev_data)); 100 101 dev->data = ml_dev_globals.data[dev_id]; 102 if (rte_eal_process_type() == RTE_PROC_PRIMARY) { 103 strlcpy(dev->data->name, name, RTE_ML_STR_MAX); 104 dev->data->dev_id = dev_id; 105 dev->data->socket_id = socket_id; 106 dev->data->dev_started = 0; 107 RTE_MLDEV_LOG(DEBUG, "PRIMARY: init mldev data"); 108 } 109 110 RTE_MLDEV_LOG(DEBUG, "Data for %s: dev_id %d, socket %u", dev->data->name, 111 dev->data->dev_id, dev->data->socket_id); 112 113 dev->attached = ML_DEV_ATTACHED; 114 ml_dev_globals.nb_devs++; 115 } 116 117 dev->enqueue_burst = NULL; 118 dev->dequeue_burst = NULL; 119 120 return dev; 121 } 122 123 int 124 rte_ml_dev_pmd_release(struct rte_ml_dev *dev) 125 { 126 char mz_name[RTE_MEMZONE_NAMESIZE]; 127 const struct rte_memzone *mz; 128 int16_t dev_id; 129 int ret = 0; 130 131 if (dev == NULL) 132 return -EINVAL; 133 134 dev_id = dev->data->dev_id; 135 136 /* Memzone lookup */ 137 sprintf(mz_name, "rte_ml_dev_data_%d", dev_id); 138 mz = rte_memzone_lookup(mz_name); 139 if (mz == NULL) 140 return -ENOMEM; 141 142 RTE_ASSERT(ml_dev_globals.data[dev_id] == mz->addr); 143 ml_dev_globals.data[dev_id] = NULL; 144 145 if (rte_eal_process_type() == RTE_PROC_PRIMARY) { 146 RTE_MLDEV_LOG(DEBUG, "PRIMARY: free memzone of %s (%p)", mz_name, mz); 147 ret = rte_memzone_free(mz); 148 } else { 149 RTE_MLDEV_LOG(DEBUG, "SECONDARY: don't free memzone of %s (%p)", mz_name, mz); 150 } 151 152 dev->attached = ML_DEV_DETACHED; 153 ml_dev_globals.nb_devs--; 154 155 return ret; 156 } 157 158 int 159 rte_ml_dev_init(size_t dev_max) 160 { 161 if (dev_max == 0 || dev_max > INT16_MAX) { 162 RTE_MLDEV_LOG(ERR, "Invalid dev_max = %zu (> %d)\n", dev_max, INT16_MAX); 163 rte_errno = EINVAL; 164 return -rte_errno; 165 } 166 167 /* No lock, it must be called before or during first probing. */ 168 if (ml_dev_globals.devs != NULL) { 169 RTE_MLDEV_LOG(ERR, "Device array already initialized"); 170 rte_errno = EBUSY; 171 return -rte_errno; 172 } 173 174 ml_dev_globals.devs = calloc(dev_max, sizeof(struct rte_ml_dev)); 175 if (ml_dev_globals.devs == NULL) { 176 RTE_MLDEV_LOG(ERR, "Cannot initialize MLDEV library"); 177 rte_errno = ENOMEM; 178 return -rte_errno; 179 } 180 181 ml_dev_globals.data = calloc(dev_max, sizeof(struct rte_ml_dev_data *)); 182 if (ml_dev_globals.data == NULL) { 183 RTE_MLDEV_LOG(ERR, "Cannot initialize MLDEV library"); 184 rte_errno = ENOMEM; 185 return -rte_errno; 186 } 187 188 ml_dev_globals.max_devs = dev_max; 189 190 return 0; 191 } 192 193 uint16_t 194 rte_ml_dev_count(void) 195 { 196 return ml_dev_globals.nb_devs; 197 } 198 199 int 200 rte_ml_dev_is_valid_dev(int16_t dev_id) 201 { 202 struct rte_ml_dev *dev = NULL; 203 204 if (dev_id >= ml_dev_globals.max_devs || ml_dev_globals.devs[dev_id].data == NULL) 205 return 0; 206 207 dev = rte_ml_dev_pmd_get_dev(dev_id); 208 if (dev->attached != ML_DEV_ATTACHED) 209 return 0; 210 else 211 return 1; 212 } 213 214 int 215 rte_ml_dev_socket_id(int16_t dev_id) 216 { 217 struct rte_ml_dev *dev; 218 219 if (!rte_ml_dev_is_valid_dev(dev_id)) { 220 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 221 return -EINVAL; 222 } 223 224 dev = rte_ml_dev_pmd_get_dev(dev_id); 225 226 return dev->data->socket_id; 227 } 228 229 int 230 rte_ml_dev_info_get(int16_t dev_id, struct rte_ml_dev_info *dev_info) 231 { 232 struct rte_ml_dev *dev; 233 234 if (!rte_ml_dev_is_valid_dev(dev_id)) { 235 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 236 return -EINVAL; 237 } 238 239 dev = rte_ml_dev_pmd_get_dev(dev_id); 240 if (*dev->dev_ops->dev_info_get == NULL) 241 return -ENOTSUP; 242 243 if (dev_info == NULL) { 244 RTE_MLDEV_LOG(ERR, "Dev %d, dev_info cannot be NULL\n", dev_id); 245 return -EINVAL; 246 } 247 memset(dev_info, 0, sizeof(struct rte_ml_dev_info)); 248 249 return (*dev->dev_ops->dev_info_get)(dev, dev_info); 250 } 251 252 int 253 rte_ml_dev_configure(int16_t dev_id, const struct rte_ml_dev_config *config) 254 { 255 struct rte_ml_dev_info dev_info; 256 struct rte_ml_dev *dev; 257 int ret; 258 259 if (!rte_ml_dev_is_valid_dev(dev_id)) { 260 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 261 return -EINVAL; 262 } 263 264 dev = rte_ml_dev_pmd_get_dev(dev_id); 265 if (*dev->dev_ops->dev_configure == NULL) 266 return -ENOTSUP; 267 268 if (dev->data->dev_started) { 269 RTE_MLDEV_LOG(ERR, "Device %d must be stopped to allow configuration", dev_id); 270 return -EBUSY; 271 } 272 273 if (config == NULL) { 274 RTE_MLDEV_LOG(ERR, "Dev %d, config cannot be NULL\n", dev_id); 275 return -EINVAL; 276 } 277 278 ret = rte_ml_dev_info_get(dev_id, &dev_info); 279 if (ret < 0) 280 return ret; 281 282 if (config->nb_queue_pairs > dev_info.max_queue_pairs) { 283 RTE_MLDEV_LOG(ERR, "Device %d num of queues %u > %u\n", dev_id, 284 config->nb_queue_pairs, dev_info.max_queue_pairs); 285 return -EINVAL; 286 } 287 288 return (*dev->dev_ops->dev_configure)(dev, config); 289 } 290 291 int 292 rte_ml_dev_close(int16_t dev_id) 293 { 294 struct rte_ml_dev *dev; 295 296 if (!rte_ml_dev_is_valid_dev(dev_id)) { 297 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 298 return -EINVAL; 299 } 300 301 dev = rte_ml_dev_pmd_get_dev(dev_id); 302 if (*dev->dev_ops->dev_close == NULL) 303 return -ENOTSUP; 304 305 /* Device must be stopped before it can be closed */ 306 if (dev->data->dev_started == 1) { 307 RTE_MLDEV_LOG(ERR, "Device %d must be stopped before closing", dev_id); 308 return -EBUSY; 309 } 310 311 return (*dev->dev_ops->dev_close)(dev); 312 } 313 314 int 315 rte_ml_dev_start(int16_t dev_id) 316 { 317 struct rte_ml_dev *dev; 318 int ret; 319 320 if (!rte_ml_dev_is_valid_dev(dev_id)) { 321 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 322 return -EINVAL; 323 } 324 325 dev = rte_ml_dev_pmd_get_dev(dev_id); 326 if (*dev->dev_ops->dev_start == NULL) 327 return -ENOTSUP; 328 329 if (dev->data->dev_started != 0) { 330 RTE_MLDEV_LOG(ERR, "Device %d is already started", dev_id); 331 return -EBUSY; 332 } 333 334 ret = (*dev->dev_ops->dev_start)(dev); 335 if (ret == 0) 336 dev->data->dev_started = 1; 337 338 return ret; 339 } 340 341 int 342 rte_ml_dev_stop(int16_t dev_id) 343 { 344 struct rte_ml_dev *dev; 345 int ret; 346 347 if (!rte_ml_dev_is_valid_dev(dev_id)) { 348 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 349 return -EINVAL; 350 } 351 352 dev = rte_ml_dev_pmd_get_dev(dev_id); 353 if (*dev->dev_ops->dev_stop == NULL) 354 return -ENOTSUP; 355 356 if (dev->data->dev_started == 0) { 357 RTE_MLDEV_LOG(ERR, "Device %d is not started", dev_id); 358 return -EBUSY; 359 } 360 361 ret = (*dev->dev_ops->dev_stop)(dev); 362 if (ret == 0) 363 dev->data->dev_started = 0; 364 365 return ret; 366 } 367 368 int 369 rte_ml_dev_queue_pair_setup(int16_t dev_id, uint16_t queue_pair_id, 370 const struct rte_ml_dev_qp_conf *qp_conf, int socket_id) 371 { 372 struct rte_ml_dev *dev; 373 374 if (!rte_ml_dev_is_valid_dev(dev_id)) { 375 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 376 return -EINVAL; 377 } 378 379 dev = rte_ml_dev_pmd_get_dev(dev_id); 380 if (*dev->dev_ops->dev_queue_pair_setup == NULL) 381 return -ENOTSUP; 382 383 if (queue_pair_id >= dev->data->nb_queue_pairs) { 384 RTE_MLDEV_LOG(ERR, "Invalid queue_pair_id = %d", queue_pair_id); 385 return -EINVAL; 386 } 387 388 if (qp_conf == NULL) { 389 RTE_MLDEV_LOG(ERR, "Dev %d, qp_conf cannot be NULL\n", dev_id); 390 return -EINVAL; 391 } 392 393 if (dev->data->dev_started) { 394 RTE_MLDEV_LOG(ERR, "Device %d must be stopped to allow configuration", dev_id); 395 return -EBUSY; 396 } 397 398 return (*dev->dev_ops->dev_queue_pair_setup)(dev, queue_pair_id, qp_conf, socket_id); 399 } 400 401 int 402 rte_ml_dev_stats_get(int16_t dev_id, struct rte_ml_dev_stats *stats) 403 { 404 struct rte_ml_dev *dev; 405 406 if (!rte_ml_dev_is_valid_dev(dev_id)) { 407 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 408 return -EINVAL; 409 } 410 411 dev = rte_ml_dev_pmd_get_dev(dev_id); 412 if (*dev->dev_ops->dev_stats_get == NULL) 413 return -ENOTSUP; 414 415 if (stats == NULL) { 416 RTE_MLDEV_LOG(ERR, "Dev %d, stats cannot be NULL\n", dev_id); 417 return -EINVAL; 418 } 419 memset(stats, 0, sizeof(struct rte_ml_dev_stats)); 420 421 return (*dev->dev_ops->dev_stats_get)(dev, stats); 422 } 423 424 void 425 rte_ml_dev_stats_reset(int16_t dev_id) 426 { 427 struct rte_ml_dev *dev; 428 429 if (!rte_ml_dev_is_valid_dev(dev_id)) { 430 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 431 return; 432 } 433 434 dev = rte_ml_dev_pmd_get_dev(dev_id); 435 if (*dev->dev_ops->dev_stats_reset == NULL) 436 return; 437 438 (*dev->dev_ops->dev_stats_reset)(dev); 439 } 440 441 int 442 rte_ml_dev_xstats_names_get(int16_t dev_id, struct rte_ml_dev_xstats_map *xstats_map, uint32_t size) 443 { 444 struct rte_ml_dev *dev; 445 446 if (!rte_ml_dev_is_valid_dev(dev_id)) { 447 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 448 return -EINVAL; 449 } 450 451 dev = rte_ml_dev_pmd_get_dev(dev_id); 452 if (*dev->dev_ops->dev_xstats_names_get == NULL) 453 return -ENOTSUP; 454 455 return (*dev->dev_ops->dev_xstats_names_get)(dev, xstats_map, size); 456 } 457 458 int 459 rte_ml_dev_xstats_by_name_get(int16_t dev_id, const char *name, uint16_t *stat_id, uint64_t *value) 460 { 461 struct rte_ml_dev *dev; 462 463 if (!rte_ml_dev_is_valid_dev(dev_id)) { 464 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 465 return -EINVAL; 466 } 467 468 dev = rte_ml_dev_pmd_get_dev(dev_id); 469 if (*dev->dev_ops->dev_xstats_by_name_get == NULL) 470 return -ENOTSUP; 471 472 if (name == NULL) { 473 RTE_MLDEV_LOG(ERR, "Dev %d, name cannot be NULL\n", dev_id); 474 return -EINVAL; 475 } 476 477 if (value == NULL) { 478 RTE_MLDEV_LOG(ERR, "Dev %d, value cannot be NULL\n", dev_id); 479 return -EINVAL; 480 } 481 482 return (*dev->dev_ops->dev_xstats_by_name_get)(dev, name, stat_id, value); 483 } 484 485 int 486 rte_ml_dev_xstats_get(int16_t dev_id, const uint16_t *stat_ids, uint64_t *values, uint16_t nb_ids) 487 { 488 struct rte_ml_dev *dev; 489 490 if (!rte_ml_dev_is_valid_dev(dev_id)) { 491 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 492 return -EINVAL; 493 } 494 495 dev = rte_ml_dev_pmd_get_dev(dev_id); 496 if (*dev->dev_ops->dev_xstats_get == NULL) 497 return -ENOTSUP; 498 499 if (stat_ids == NULL) { 500 RTE_MLDEV_LOG(ERR, "Dev %d, stat_ids cannot be NULL\n", dev_id); 501 return -EINVAL; 502 } 503 504 if (values == NULL) { 505 RTE_MLDEV_LOG(ERR, "Dev %d, values cannot be NULL\n", dev_id); 506 return -EINVAL; 507 } 508 509 return (*dev->dev_ops->dev_xstats_get)(dev, stat_ids, values, nb_ids); 510 } 511 512 int 513 rte_ml_dev_xstats_reset(int16_t dev_id, const uint16_t *stat_ids, uint16_t nb_ids) 514 { 515 struct rte_ml_dev *dev; 516 517 if (!rte_ml_dev_is_valid_dev(dev_id)) { 518 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 519 return -EINVAL; 520 } 521 522 dev = rte_ml_dev_pmd_get_dev(dev_id); 523 if (*dev->dev_ops->dev_xstats_reset == NULL) 524 return -ENOTSUP; 525 526 return (*dev->dev_ops->dev_xstats_reset)(dev, stat_ids, nb_ids); 527 } 528 529 int 530 rte_ml_dev_dump(int16_t dev_id, FILE *fd) 531 { 532 struct rte_ml_dev *dev; 533 534 if (!rte_ml_dev_is_valid_dev(dev_id)) { 535 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 536 return -EINVAL; 537 } 538 539 dev = rte_ml_dev_pmd_get_dev(dev_id); 540 if (*dev->dev_ops->dev_dump == NULL) 541 return -ENOTSUP; 542 543 if (fd == NULL) { 544 RTE_MLDEV_LOG(ERR, "Dev %d, file descriptor cannot be NULL\n", dev_id); 545 return -EINVAL; 546 } 547 548 return (*dev->dev_ops->dev_dump)(dev, fd); 549 } 550 551 int 552 rte_ml_dev_selftest(int16_t dev_id) 553 { 554 struct rte_ml_dev *dev; 555 556 if (!rte_ml_dev_is_valid_dev(dev_id)) { 557 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 558 return -EINVAL; 559 } 560 561 dev = rte_ml_dev_pmd_get_dev(dev_id); 562 if (*dev->dev_ops->dev_selftest == NULL) 563 return -ENOTSUP; 564 565 return (*dev->dev_ops->dev_selftest)(dev); 566 } 567 568 int 569 rte_ml_model_load(int16_t dev_id, struct rte_ml_model_params *params, uint16_t *model_id) 570 { 571 struct rte_ml_dev *dev; 572 573 if (!rte_ml_dev_is_valid_dev(dev_id)) { 574 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 575 return -EINVAL; 576 } 577 578 dev = rte_ml_dev_pmd_get_dev(dev_id); 579 if (*dev->dev_ops->model_load == NULL) 580 return -ENOTSUP; 581 582 if (params == NULL) { 583 RTE_MLDEV_LOG(ERR, "Dev %d, params cannot be NULL\n", dev_id); 584 return -EINVAL; 585 } 586 587 if (model_id == NULL) { 588 RTE_MLDEV_LOG(ERR, "Dev %d, model_id cannot be NULL\n", dev_id); 589 return -EINVAL; 590 } 591 592 return (*dev->dev_ops->model_load)(dev, params, model_id); 593 } 594 595 int 596 rte_ml_model_unload(int16_t dev_id, uint16_t model_id) 597 { 598 struct rte_ml_dev *dev; 599 600 if (!rte_ml_dev_is_valid_dev(dev_id)) { 601 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 602 return -EINVAL; 603 } 604 605 dev = rte_ml_dev_pmd_get_dev(dev_id); 606 if (*dev->dev_ops->model_unload == NULL) 607 return -ENOTSUP; 608 609 return (*dev->dev_ops->model_unload)(dev, model_id); 610 } 611 612 int 613 rte_ml_model_start(int16_t dev_id, uint16_t model_id) 614 { 615 struct rte_ml_dev *dev; 616 617 if (!rte_ml_dev_is_valid_dev(dev_id)) { 618 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 619 return -EINVAL; 620 } 621 622 dev = rte_ml_dev_pmd_get_dev(dev_id); 623 if (*dev->dev_ops->model_start == NULL) 624 return -ENOTSUP; 625 626 return (*dev->dev_ops->model_start)(dev, model_id); 627 } 628 629 int 630 rte_ml_model_stop(int16_t dev_id, uint16_t model_id) 631 { 632 struct rte_ml_dev *dev; 633 634 if (!rte_ml_dev_is_valid_dev(dev_id)) { 635 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 636 return -EINVAL; 637 } 638 639 dev = rte_ml_dev_pmd_get_dev(dev_id); 640 if (*dev->dev_ops->model_stop == NULL) 641 return -ENOTSUP; 642 643 return (*dev->dev_ops->model_stop)(dev, model_id); 644 } 645 646 int 647 rte_ml_model_info_get(int16_t dev_id, uint16_t model_id, struct rte_ml_model_info *model_info) 648 { 649 struct rte_ml_dev *dev; 650 651 if (!rte_ml_dev_is_valid_dev(dev_id)) { 652 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 653 return -EINVAL; 654 } 655 656 dev = rte_ml_dev_pmd_get_dev(dev_id); 657 if (*dev->dev_ops->model_info_get == NULL) 658 return -ENOTSUP; 659 660 if (model_info == NULL) { 661 RTE_MLDEV_LOG(ERR, "Dev %d, model_id %u, model_info cannot be NULL\n", dev_id, 662 model_id); 663 return -EINVAL; 664 } 665 666 return (*dev->dev_ops->model_info_get)(dev, model_id, model_info); 667 } 668 669 int 670 rte_ml_model_params_update(int16_t dev_id, uint16_t model_id, void *buffer) 671 { 672 struct rte_ml_dev *dev; 673 674 if (!rte_ml_dev_is_valid_dev(dev_id)) { 675 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 676 return -EINVAL; 677 } 678 679 dev = rte_ml_dev_pmd_get_dev(dev_id); 680 if (*dev->dev_ops->model_params_update == NULL) 681 return -ENOTSUP; 682 683 if (buffer == NULL) { 684 RTE_MLDEV_LOG(ERR, "Dev %d, buffer cannot be NULL\n", dev_id); 685 return -EINVAL; 686 } 687 688 return (*dev->dev_ops->model_params_update)(dev, model_id, buffer); 689 } 690 691 int 692 rte_ml_io_input_size_get(int16_t dev_id, uint16_t model_id, uint32_t nb_batches, 693 uint64_t *input_qsize, uint64_t *input_dsize) 694 { 695 struct rte_ml_dev *dev; 696 697 if (!rte_ml_dev_is_valid_dev(dev_id)) { 698 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 699 return -EINVAL; 700 } 701 702 dev = rte_ml_dev_pmd_get_dev(dev_id); 703 if (*dev->dev_ops->io_input_size_get == NULL) 704 return -ENOTSUP; 705 706 return (*dev->dev_ops->io_input_size_get)(dev, model_id, nb_batches, input_qsize, 707 input_dsize); 708 } 709 710 int 711 rte_ml_io_output_size_get(int16_t dev_id, uint16_t model_id, uint32_t nb_batches, 712 uint64_t *output_qsize, uint64_t *output_dsize) 713 { 714 struct rte_ml_dev *dev; 715 716 if (!rte_ml_dev_is_valid_dev(dev_id)) { 717 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 718 return -EINVAL; 719 } 720 721 dev = rte_ml_dev_pmd_get_dev(dev_id); 722 if (*dev->dev_ops->io_output_size_get == NULL) 723 return -ENOTSUP; 724 725 return (*dev->dev_ops->io_output_size_get)(dev, model_id, nb_batches, output_qsize, 726 output_dsize); 727 } 728 729 int 730 rte_ml_io_quantize(int16_t dev_id, uint16_t model_id, uint16_t nb_batches, void *dbuffer, 731 void *qbuffer) 732 { 733 struct rte_ml_dev *dev; 734 735 if (!rte_ml_dev_is_valid_dev(dev_id)) { 736 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 737 return -EINVAL; 738 } 739 740 dev = rte_ml_dev_pmd_get_dev(dev_id); 741 if (*dev->dev_ops->io_quantize == NULL) 742 return -ENOTSUP; 743 744 if (dbuffer == NULL) { 745 RTE_MLDEV_LOG(ERR, "Dev %d, dbuffer cannot be NULL\n", dev_id); 746 return -EINVAL; 747 } 748 749 if (qbuffer == NULL) { 750 RTE_MLDEV_LOG(ERR, "Dev %d, qbuffer cannot be NULL\n", dev_id); 751 return -EINVAL; 752 } 753 754 return (*dev->dev_ops->io_quantize)(dev, model_id, nb_batches, dbuffer, qbuffer); 755 } 756 757 int 758 rte_ml_io_dequantize(int16_t dev_id, uint16_t model_id, uint16_t nb_batches, void *qbuffer, 759 void *dbuffer) 760 { 761 struct rte_ml_dev *dev; 762 763 if (!rte_ml_dev_is_valid_dev(dev_id)) { 764 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 765 return -EINVAL; 766 } 767 768 dev = rte_ml_dev_pmd_get_dev(dev_id); 769 if (*dev->dev_ops->io_dequantize == NULL) 770 return -ENOTSUP; 771 772 if (qbuffer == NULL) { 773 RTE_MLDEV_LOG(ERR, "Dev %d, qbuffer cannot be NULL\n", dev_id); 774 return -EINVAL; 775 } 776 777 if (dbuffer == NULL) { 778 RTE_MLDEV_LOG(ERR, "Dev %d, dbuffer cannot be NULL\n", dev_id); 779 return -EINVAL; 780 } 781 782 return (*dev->dev_ops->io_dequantize)(dev, model_id, nb_batches, qbuffer, dbuffer); 783 } 784 785 /** Initialise rte_ml_op mempool element */ 786 static void 787 ml_op_init(struct rte_mempool *mempool, __rte_unused void *opaque_arg, void *_op_data, 788 __rte_unused unsigned int i) 789 { 790 struct rte_ml_op *op = _op_data; 791 792 memset(_op_data, 0, mempool->elt_size); 793 op->status = RTE_ML_OP_STATUS_NOT_PROCESSED; 794 op->mempool = mempool; 795 } 796 797 struct rte_mempool * 798 rte_ml_op_pool_create(const char *name, unsigned int nb_elts, unsigned int cache_size, 799 uint16_t user_size, int socket_id) 800 { 801 struct rte_ml_op_pool_private *priv; 802 struct rte_mempool *mp; 803 unsigned int elt_size; 804 805 /* lookup mempool in case already allocated */ 806 mp = rte_mempool_lookup(name); 807 elt_size = sizeof(struct rte_ml_op) + user_size; 808 809 if (mp != NULL) { 810 priv = (struct rte_ml_op_pool_private *)rte_mempool_get_priv(mp); 811 if (mp->elt_size != elt_size || mp->cache_size < cache_size || mp->size < nb_elts || 812 priv->user_size < user_size) { 813 mp = NULL; 814 RTE_MLDEV_LOG(ERR, 815 "Mempool %s already exists but with incompatible parameters", 816 name); 817 return NULL; 818 } 819 return mp; 820 } 821 822 mp = rte_mempool_create(name, nb_elts, elt_size, cache_size, 823 sizeof(struct rte_ml_op_pool_private), NULL, NULL, ml_op_init, NULL, 824 socket_id, 0); 825 if (mp == NULL) { 826 RTE_MLDEV_LOG(ERR, "Failed to create mempool %s", name); 827 return NULL; 828 } 829 830 priv = (struct rte_ml_op_pool_private *)rte_mempool_get_priv(mp); 831 priv->user_size = user_size; 832 833 return mp; 834 } 835 836 void 837 rte_ml_op_pool_free(struct rte_mempool *mempool) 838 { 839 if (mempool != NULL) 840 rte_mempool_free(mempool); 841 } 842 843 uint16_t 844 rte_ml_enqueue_burst(int16_t dev_id, uint16_t qp_id, struct rte_ml_op **ops, uint16_t nb_ops) 845 { 846 struct rte_ml_dev *dev; 847 848 #ifdef RTE_LIBRTE_ML_DEV_DEBUG 849 if (!rte_ml_dev_is_valid_dev(dev_id)) { 850 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 851 rte_errno = -EINVAL; 852 return 0; 853 } 854 855 dev = rte_ml_dev_pmd_get_dev(dev_id); 856 if (*dev->enqueue_burst == NULL) { 857 rte_errno = -ENOTSUP; 858 return 0; 859 } 860 861 if (ops == NULL) { 862 RTE_MLDEV_LOG(ERR, "Dev %d, ops cannot be NULL\n", dev_id); 863 rte_errno = -EINVAL; 864 return 0; 865 } 866 867 if (qp_id >= dev->data->nb_queue_pairs) { 868 RTE_MLDEV_LOG(ERR, "Invalid qp_id %u\n", qp_id); 869 rte_errno = -EINVAL; 870 return 0; 871 } 872 #else 873 dev = rte_ml_dev_pmd_get_dev(dev_id); 874 #endif 875 876 return (*dev->enqueue_burst)(dev, qp_id, ops, nb_ops); 877 } 878 879 uint16_t 880 rte_ml_dequeue_burst(int16_t dev_id, uint16_t qp_id, struct rte_ml_op **ops, uint16_t nb_ops) 881 { 882 struct rte_ml_dev *dev; 883 884 #ifdef RTE_LIBRTE_ML_DEV_DEBUG 885 if (!rte_ml_dev_is_valid_dev(dev_id)) { 886 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 887 rte_errno = -EINVAL; 888 return 0; 889 } 890 891 dev = rte_ml_dev_pmd_get_dev(dev_id); 892 if (*dev->dequeue_burst == NULL) { 893 rte_errno = -ENOTSUP; 894 return 0; 895 } 896 897 if (ops == NULL) { 898 RTE_MLDEV_LOG(ERR, "Dev %d, ops cannot be NULL\n", dev_id); 899 rte_errno = -EINVAL; 900 return 0; 901 } 902 903 if (qp_id >= dev->data->nb_queue_pairs) { 904 RTE_MLDEV_LOG(ERR, "Invalid qp_id %u\n", qp_id); 905 rte_errno = -EINVAL; 906 return 0; 907 } 908 #else 909 dev = rte_ml_dev_pmd_get_dev(dev_id); 910 #endif 911 912 return (*dev->dequeue_burst)(dev, qp_id, ops, nb_ops); 913 } 914 915 int 916 rte_ml_op_error_get(int16_t dev_id, struct rte_ml_op *op, struct rte_ml_op_error *error) 917 { 918 struct rte_ml_dev *dev; 919 920 #ifdef RTE_LIBRTE_ML_DEV_DEBUG 921 if (!rte_ml_dev_is_valid_dev(dev_id)) { 922 RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id); 923 return -EINVAL; 924 } 925 926 dev = rte_ml_dev_pmd_get_dev(dev_id); 927 if (*dev->op_error_get == NULL) 928 return -ENOTSUP; 929 930 if (op == NULL) { 931 RTE_MLDEV_LOG(ERR, "Dev %d, op cannot be NULL\n", dev_id); 932 return -EINVAL; 933 } 934 935 if (error == NULL) { 936 RTE_MLDEV_LOG(ERR, "Dev %d, error cannot be NULL\n", dev_id); 937 return -EINVAL; 938 } 939 #else 940 dev = rte_ml_dev_pmd_get_dev(dev_id); 941 #endif 942 943 return (*dev->op_error_get)(dev, op, error); 944 } 945 946 RTE_LOG_REGISTER_DEFAULT(rte_ml_dev_logtype, INFO); 947