xref: /dpdk/app/test-mldev/test_model_common.c (revision 24364292061f197a9f608b1020b6886ec7b1216d)
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright (c) 2022 Marvell.
3  */
4 
5 #include <errno.h>
6 
7 #include <rte_common.h>
8 #include <rte_malloc.h>
9 #include <rte_mldev.h>
10 
11 #include "ml_common.h"
12 #include "test_model_common.h"
13 
14 int
ml_model_load(struct ml_test * test,struct ml_options * opt,struct ml_model * model,uint16_t fid)15 ml_model_load(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid)
16 {
17 	struct rte_ml_model_params model_params;
18 	int ret;
19 
20 	RTE_SET_USED(test);
21 
22 	if (model->state == MODEL_LOADED)
23 		return 0;
24 
25 	if (model->state != MODEL_INITIAL)
26 		return -EINVAL;
27 
28 	/* read model binary */
29 	ret = ml_read_file(opt->filelist[fid].model, &model_params.size,
30 			   (char **)&model_params.addr);
31 	if (ret != 0)
32 		return ret;
33 
34 	/* load model to device */
35 	ret = rte_ml_model_load(opt->dev_id, &model_params, &model->id);
36 	if (ret != 0) {
37 		ml_err("Failed to load model : %s\n", opt->filelist[fid].model);
38 		model->state = MODEL_ERROR;
39 		free(model_params.addr);
40 		return ret;
41 	}
42 
43 	/* release buffer */
44 	free(model_params.addr);
45 
46 	/* get model info */
47 	ret = rte_ml_model_info_get(opt->dev_id, model->id, &model->info);
48 	if (ret != 0) {
49 		ml_err("Failed to get model info : %s\n", opt->filelist[fid].model);
50 		return ret;
51 	}
52 
53 	model->state = MODEL_LOADED;
54 
55 	return 0;
56 }
57 
58 int
ml_model_unload(struct ml_test * test,struct ml_options * opt,struct ml_model * model,uint16_t fid)59 ml_model_unload(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid)
60 {
61 	struct test_common *t = ml_test_priv(test);
62 	int ret;
63 
64 	RTE_SET_USED(t);
65 
66 	if (model->state == MODEL_INITIAL)
67 		return 0;
68 
69 	if (model->state != MODEL_LOADED)
70 		return -EINVAL;
71 
72 	/* unload model */
73 	ret = rte_ml_model_unload(opt->dev_id, model->id);
74 	if (ret != 0) {
75 		ml_err("Failed to unload model: %s\n", opt->filelist[fid].model);
76 		model->state = MODEL_ERROR;
77 		return ret;
78 	}
79 
80 	model->state = MODEL_INITIAL;
81 
82 	return 0;
83 }
84 
85 int
ml_model_start(struct ml_test * test,struct ml_options * opt,struct ml_model * model,uint16_t fid)86 ml_model_start(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid)
87 {
88 	struct test_common *t = ml_test_priv(test);
89 	int ret;
90 
91 	RTE_SET_USED(t);
92 
93 	if (model->state == MODEL_STARTED)
94 		return 0;
95 
96 	if (model->state != MODEL_LOADED)
97 		return -EINVAL;
98 
99 	/* start model */
100 	ret = rte_ml_model_start(opt->dev_id, model->id);
101 	if (ret != 0) {
102 		ml_err("Failed to start model : %s\n", opt->filelist[fid].model);
103 		model->state = MODEL_ERROR;
104 		return ret;
105 	}
106 
107 	model->state = MODEL_STARTED;
108 
109 	return 0;
110 }
111 
112 int
ml_model_stop(struct ml_test * test,struct ml_options * opt,struct ml_model * model,uint16_t fid)113 ml_model_stop(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid)
114 {
115 	struct test_common *t = ml_test_priv(test);
116 	int ret;
117 
118 	RTE_SET_USED(t);
119 
120 	if (model->state == MODEL_LOADED)
121 		return 0;
122 
123 	if (model->state != MODEL_STARTED)
124 		return -EINVAL;
125 
126 	/* stop model */
127 	ret = rte_ml_model_stop(opt->dev_id, model->id);
128 	if (ret != 0) {
129 		ml_err("Failed to stop model: %s\n", opt->filelist[fid].model);
130 		model->state = MODEL_ERROR;
131 		return ret;
132 	}
133 
134 	model->state = MODEL_LOADED;
135 
136 	return 0;
137 }
138