1 /* SPDX-License-Identifier: BSD-3-Clause 2 * Copyright (c) 2022 Marvell. 3 */ 4 5 #include <errno.h> 6 7 #include <rte_common.h> 8 #include <rte_malloc.h> 9 #include <rte_mldev.h> 10 11 #include "ml_common.h" 12 #include "test_model_common.h" 13 14 int 15 ml_model_load(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid) 16 { 17 struct rte_ml_model_params model_params; 18 int ret; 19 20 RTE_SET_USED(test); 21 22 if (model->state == MODEL_LOADED) 23 return 0; 24 25 if (model->state != MODEL_INITIAL) 26 return -EINVAL; 27 28 /* read model binary */ 29 ret = ml_read_file(opt->filelist[fid].model, &model_params.size, 30 (char **)&model_params.addr); 31 if (ret != 0) 32 return ret; 33 34 /* load model to device */ 35 ret = rte_ml_model_load(opt->dev_id, &model_params, &model->id); 36 if (ret != 0) { 37 ml_err("Failed to load model : %s\n", opt->filelist[fid].model); 38 model->state = MODEL_ERROR; 39 free(model_params.addr); 40 return ret; 41 } 42 43 /* release buffer */ 44 free(model_params.addr); 45 46 /* get model info */ 47 ret = rte_ml_model_info_get(opt->dev_id, model->id, &model->info); 48 if (ret != 0) { 49 ml_err("Failed to get model info : %s\n", opt->filelist[fid].model); 50 return ret; 51 } 52 53 /* Update number of batches */ 54 if (opt->batches == 0) 55 model->nb_batches = model->info.batch_size; 56 else 57 model->nb_batches = opt->batches; 58 59 model->state = MODEL_LOADED; 60 61 return 0; 62 } 63 64 int 65 ml_model_unload(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid) 66 { 67 struct test_common *t = ml_test_priv(test); 68 int ret; 69 70 RTE_SET_USED(t); 71 72 if (model->state == MODEL_INITIAL) 73 return 0; 74 75 if (model->state != MODEL_LOADED) 76 return -EINVAL; 77 78 /* unload model */ 79 ret = rte_ml_model_unload(opt->dev_id, model->id); 80 if (ret != 0) { 81 ml_err("Failed to unload model: %s\n", opt->filelist[fid].model); 82 model->state = MODEL_ERROR; 83 return ret; 84 } 85 86 model->state = MODEL_INITIAL; 87 88 return 0; 89 } 90 91 int 92 ml_model_start(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid) 93 { 94 struct test_common *t = ml_test_priv(test); 95 int ret; 96 97 RTE_SET_USED(t); 98 99 if (model->state == MODEL_STARTED) 100 return 0; 101 102 if (model->state != MODEL_LOADED) 103 return -EINVAL; 104 105 /* start model */ 106 ret = rte_ml_model_start(opt->dev_id, model->id); 107 if (ret != 0) { 108 ml_err("Failed to start model : %s\n", opt->filelist[fid].model); 109 model->state = MODEL_ERROR; 110 return ret; 111 } 112 113 model->state = MODEL_STARTED; 114 115 return 0; 116 } 117 118 int 119 ml_model_stop(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid) 120 { 121 struct test_common *t = ml_test_priv(test); 122 int ret; 123 124 RTE_SET_USED(t); 125 126 if (model->state == MODEL_LOADED) 127 return 0; 128 129 if (model->state != MODEL_STARTED) 130 return -EINVAL; 131 132 /* stop model */ 133 ret = rte_ml_model_stop(opt->dev_id, model->id); 134 if (ret != 0) { 135 ml_err("Failed to stop model: %s\n", opt->filelist[fid].model); 136 model->state = MODEL_ERROR; 137 return ret; 138 } 139 140 model->state = MODEL_LOADED; 141 142 return 0; 143 } 144