1 /* SPDX-License-Identifier: BSD-3-Clause 2 * Copyright (c) 2023 Marvell. 3 */ 4 5 #include <rte_kvargs.h> 6 #include <rte_mldev.h> 7 #include <rte_mldev_pmd.h> 8 9 #include <bus_vdev_driver.h> 10 11 #include <roc_api.h> 12 13 #include "cnxk_ml_dev.h" 14 15 #define MVTVM_ML_DEV_MAX_QPS "max_qps" 16 #define MVTVM_ML_DEV_CACHE_MODEL_DATA "cache_model_data" 17 18 #define MVTVM_ML_DEV_MAX_QPS_DEFAULT 32 19 #define CN10K_ML_DEV_CACHE_MODEL_DATA_DEFAULT 1 20 21 static const char *const valid_args[] = {MVTVM_ML_DEV_MAX_QPS, MVTVM_ML_DEV_CACHE_MODEL_DATA, NULL}; 22 23 static int 24 parse_integer_arg(const char *key __rte_unused, const char *value, void *extra_args) 25 { 26 int *i = (int *)extra_args; 27 28 *i = atoi(value); 29 if (*i < 0) { 30 plt_err("Argument has to be positive."); 31 return -EINVAL; 32 } 33 34 return 0; 35 } 36 37 static int 38 parse_uint_arg(const char *key __rte_unused, const char *value, void *extra_args) 39 { 40 int i; 41 char *end; 42 errno = 0; 43 44 i = strtol(value, &end, 10); 45 if (*end != 0 || errno != 0 || i < 0) 46 return -EINVAL; 47 48 *((uint32_t *)extra_args) = i; 49 50 return 0; 51 } 52 53 static int 54 mvtvm_mldev_parse_devargs(const char *args, struct mvtvm_ml_dev *mvtvm_mldev) 55 { 56 bool cache_model_data_set = false; 57 struct rte_kvargs *kvlist = NULL; 58 bool max_qps_set = false; 59 int ret = 0; 60 61 if (args == NULL) 62 goto check_args; 63 64 kvlist = rte_kvargs_parse(args, valid_args); 65 if (kvlist == NULL) { 66 plt_err("Error parsing %s devargs", "MLDEV_NAME_MVTVM_PMD"); 67 return -EINVAL; 68 } 69 70 if (rte_kvargs_count(kvlist, MVTVM_ML_DEV_MAX_QPS) == 1) { 71 ret = rte_kvargs_process(kvlist, MVTVM_ML_DEV_MAX_QPS, &parse_uint_arg, 72 &mvtvm_mldev->max_nb_qpairs); 73 if (ret < 0) { 74 plt_err("Error processing arguments, key = %s", MVTVM_ML_DEV_MAX_QPS); 75 ret = -EINVAL; 76 goto exit; 77 } 78 max_qps_set = true; 79 } 80 81 if (rte_kvargs_count(kvlist, MVTVM_ML_DEV_CACHE_MODEL_DATA) == 1) { 82 ret = rte_kvargs_process(kvlist, MVTVM_ML_DEV_CACHE_MODEL_DATA, &parse_integer_arg, 83 &mvtvm_mldev->cache_model_data); 84 if (ret < 0) { 85 plt_err("Error processing arguments, key = %s", 86 MVTVM_ML_DEV_CACHE_MODEL_DATA); 87 ret = -EINVAL; 88 goto exit; 89 } 90 cache_model_data_set = true; 91 } 92 93 check_args: 94 if (!max_qps_set) 95 mvtvm_mldev->max_nb_qpairs = MVTVM_ML_DEV_MAX_QPS_DEFAULT; 96 plt_ml_dbg("ML: %s = %u", MVTVM_ML_DEV_MAX_QPS, mvtvm_mldev->max_nb_qpairs); 97 98 if (!cache_model_data_set) { 99 mvtvm_mldev->cache_model_data = CN10K_ML_DEV_CACHE_MODEL_DATA_DEFAULT; 100 } else { 101 if ((mvtvm_mldev->cache_model_data < 0) || (mvtvm_mldev->cache_model_data > 1)) { 102 plt_err("Invalid argument, %s = %d", MVTVM_ML_DEV_CACHE_MODEL_DATA, 103 mvtvm_mldev->cache_model_data); 104 ret = -EINVAL; 105 goto exit; 106 } 107 } 108 plt_ml_dbg("ML: %s = %d", MVTVM_ML_DEV_CACHE_MODEL_DATA, mvtvm_mldev->cache_model_data); 109 110 exit: 111 rte_kvargs_free(kvlist); 112 113 return ret; 114 } 115 116 static int 117 mvtvm_ml_vdev_probe(struct rte_vdev_device *vdev) 118 { 119 struct rte_ml_dev_pmd_init_params init_params; 120 struct mvtvm_ml_dev *mvtvm_mldev; 121 struct cnxk_ml_dev *cnxk_mldev; 122 struct rte_ml_dev *dev; 123 const char *input_args; 124 const char *name; 125 int ret = 0; 126 127 if (cnxk_ml_dev_initialized == 1) { 128 plt_err("ML CNXK device already initialized!"); 129 plt_err("Not creating ml_mvtvm vdev!"); 130 return 0; 131 } 132 133 init_params = (struct rte_ml_dev_pmd_init_params){ 134 .socket_id = rte_socket_id(), .private_data_size = sizeof(struct cnxk_ml_dev)}; 135 136 name = rte_vdev_device_name(vdev); 137 if (name == NULL) 138 return -EINVAL; 139 input_args = rte_vdev_device_args(vdev); 140 141 dev = rte_ml_dev_pmd_create(name, &vdev->device, &init_params); 142 if (dev == NULL) { 143 ret = -EFAULT; 144 goto error_exit; 145 } 146 147 cnxk_mldev = dev->data->dev_private; 148 cnxk_mldev->mldev = dev; 149 mvtvm_mldev = &cnxk_mldev->mvtvm_mldev; 150 mvtvm_mldev->vdev = vdev; 151 152 ret = mvtvm_mldev_parse_devargs(input_args, mvtvm_mldev); 153 if (ret < 0) 154 goto error_exit; 155 156 dev->dev_ops = &cnxk_ml_ops; 157 dev->enqueue_burst = NULL; 158 dev->dequeue_burst = NULL; 159 dev->op_error_get = NULL; 160 161 cnxk_ml_dev_initialized = 1; 162 cnxk_mldev->type = CNXK_ML_DEV_TYPE_VDEV; 163 164 return 0; 165 166 error_exit: 167 plt_err("Could not create device: ml_mvtvm"); 168 169 return ret; 170 } 171 172 static int 173 mvtvm_ml_vdev_remove(struct rte_vdev_device *vdev) 174 { 175 struct rte_ml_dev *dev; 176 const char *name; 177 178 name = rte_vdev_device_name(vdev); 179 if (name == NULL) 180 return -EINVAL; 181 182 dev = rte_ml_dev_pmd_get_named_dev(name); 183 if (dev == NULL) 184 return -ENODEV; 185 186 return rte_ml_dev_pmd_destroy(dev); 187 } 188 189 static struct rte_vdev_driver mvtvm_mldev_pmd = {.probe = mvtvm_ml_vdev_probe, 190 .remove = mvtvm_ml_vdev_remove}; 191 192 RTE_PMD_REGISTER_VDEV(MLDEV_NAME_MVTVM_PMD, mvtvm_mldev_pmd); 193 194 RTE_PMD_REGISTER_PARAM_STRING(MLDEV_NAME_MVTVM_PMD, 195 MVTVM_ML_DEV_MAX_QPS "=<int>" MVTVM_ML_DEV_CACHE_MODEL_DATA "=<0|1>"); 196