xref: /dpdk/app/test-mldev/test_model_common.c (revision 62774b78a84e9fa5df56d04cffed69bef8c901f1)
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright (c) 2022 Marvell.
3  */
4 
5 #include <errno.h>
6 
7 #include <rte_common.h>
8 #include <rte_malloc.h>
9 #include <rte_mldev.h>
10 
11 #include "ml_common.h"
12 #include "test_model_common.h"
13 
14 int
15 ml_model_load(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid)
16 {
17 	struct rte_ml_model_params model_params;
18 	int ret;
19 
20 	RTE_SET_USED(test);
21 
22 	if (model->state == MODEL_LOADED)
23 		return 0;
24 
25 	if (model->state != MODEL_INITIAL)
26 		return -EINVAL;
27 
28 	/* read model binary */
29 	ret = ml_read_file(opt->filelist[fid].model, &model_params.size,
30 			   (char **)&model_params.addr);
31 	if (ret != 0)
32 		return ret;
33 
34 	/* load model to device */
35 	ret = rte_ml_model_load(opt->dev_id, &model_params, &model->id);
36 	if (ret != 0) {
37 		ml_err("Failed to load model : %s\n", opt->filelist[fid].model);
38 		model->state = MODEL_ERROR;
39 		free(model_params.addr);
40 		return ret;
41 	}
42 
43 	/* release buffer */
44 	free(model_params.addr);
45 
46 	/* get model info */
47 	ret = rte_ml_model_info_get(opt->dev_id, model->id, &model->info);
48 	if (ret != 0) {
49 		ml_err("Failed to get model info : %s\n", opt->filelist[fid].model);
50 		return ret;
51 	}
52 
53 	/* Update number of batches */
54 	if (opt->batches == 0)
55 		model->nb_batches = model->info.batch_size;
56 	else
57 		model->nb_batches = opt->batches;
58 
59 	model->state = MODEL_LOADED;
60 
61 	return 0;
62 }
63 
64 int
65 ml_model_unload(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid)
66 {
67 	struct test_common *t = ml_test_priv(test);
68 	int ret;
69 
70 	RTE_SET_USED(t);
71 
72 	if (model->state == MODEL_INITIAL)
73 		return 0;
74 
75 	if (model->state != MODEL_LOADED)
76 		return -EINVAL;
77 
78 	/* unload model */
79 	ret = rte_ml_model_unload(opt->dev_id, model->id);
80 	if (ret != 0) {
81 		ml_err("Failed to unload model: %s\n", opt->filelist[fid].model);
82 		model->state = MODEL_ERROR;
83 		return ret;
84 	}
85 
86 	model->state = MODEL_INITIAL;
87 
88 	return 0;
89 }
90 
91 int
92 ml_model_start(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid)
93 {
94 	struct test_common *t = ml_test_priv(test);
95 	int ret;
96 
97 	RTE_SET_USED(t);
98 
99 	if (model->state == MODEL_STARTED)
100 		return 0;
101 
102 	if (model->state != MODEL_LOADED)
103 		return -EINVAL;
104 
105 	/* start model */
106 	ret = rte_ml_model_start(opt->dev_id, model->id);
107 	if (ret != 0) {
108 		ml_err("Failed to start model : %s\n", opt->filelist[fid].model);
109 		model->state = MODEL_ERROR;
110 		return ret;
111 	}
112 
113 	model->state = MODEL_STARTED;
114 
115 	return 0;
116 }
117 
118 int
119 ml_model_stop(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid)
120 {
121 	struct test_common *t = ml_test_priv(test);
122 	int ret;
123 
124 	RTE_SET_USED(t);
125 
126 	if (model->state == MODEL_LOADED)
127 		return 0;
128 
129 	if (model->state != MODEL_STARTED)
130 		return -EINVAL;
131 
132 	/* stop model */
133 	ret = rte_ml_model_stop(opt->dev_id, model->id);
134 	if (ret != 0) {
135 		ml_err("Failed to stop model: %s\n", opt->filelist[fid].model);
136 		model->state = MODEL_ERROR;
137 		return ret;
138 	}
139 
140 	model->state = MODEL_LOADED;
141 
142 	return 0;
143 }
144