xref: /dpdk/app/test-mldev/test_model_common.c (revision 2bf48044dca1892e571fd4964eecaacf6cb0c1c2)
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright (c) 2022 Marvell.
3  */
4 
5 #include <errno.h>
6 
7 #include <rte_common.h>
8 #include <rte_malloc.h>
9 #include <rte_mldev.h>
10 
11 #include "ml_common.h"
12 #include "test_model_common.h"
13 
14 int
15 ml_model_load(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid)
16 {
17 	struct test_common *t = ml_test_priv(test);
18 	struct rte_ml_model_params model_params;
19 	FILE *fp;
20 	int ret;
21 
22 	if (model->state == MODEL_LOADED)
23 		return 0;
24 
25 	if (model->state != MODEL_INITIAL)
26 		return -EINVAL;
27 
28 	/* read model binary */
29 	fp = fopen(opt->filelist[fid].model, "r");
30 	if (fp == NULL) {
31 		ml_err("Failed to open model file : %s\n", opt->filelist[fid].model);
32 		return -1;
33 	}
34 
35 	fseek(fp, 0, SEEK_END);
36 	model_params.size = ftell(fp);
37 	fseek(fp, 0, SEEK_SET);
38 
39 	model_params.addr = rte_malloc_socket("ml_model", model_params.size,
40 					      t->dev_info.min_align_size, opt->socket_id);
41 	if (model_params.addr == NULL) {
42 		ml_err("Failed to allocate memory for model: %s\n", opt->filelist[fid].model);
43 		fclose(fp);
44 		return -ENOMEM;
45 	}
46 
47 	if (fread(model_params.addr, 1, model_params.size, fp) != model_params.size) {
48 		ml_err("Failed to read model file : %s\n", opt->filelist[fid].model);
49 		rte_free(model_params.addr);
50 		fclose(fp);
51 		return -1;
52 	}
53 	fclose(fp);
54 
55 	/* load model to device */
56 	ret = rte_ml_model_load(opt->dev_id, &model_params, &model->id);
57 	if (ret != 0) {
58 		ml_err("Failed to load model : %s\n", opt->filelist[fid].model);
59 		model->state = MODEL_ERROR;
60 		rte_free(model_params.addr);
61 		return ret;
62 	}
63 
64 	/* release mz */
65 	rte_free(model_params.addr);
66 
67 	/* get model info */
68 	ret = rte_ml_model_info_get(opt->dev_id, model->id, &model->info);
69 	if (ret != 0) {
70 		ml_err("Failed to get model info : %s\n", opt->filelist[fid].model);
71 		return ret;
72 	}
73 
74 	/* Update number of batches */
75 	if (opt->batches == 0)
76 		model->nb_batches = model->info.batch_size;
77 	else
78 		model->nb_batches = opt->batches;
79 
80 	model->state = MODEL_LOADED;
81 
82 	return 0;
83 }
84 
85 int
86 ml_model_unload(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid)
87 {
88 	struct test_common *t = ml_test_priv(test);
89 	int ret;
90 
91 	RTE_SET_USED(t);
92 
93 	if (model->state == MODEL_INITIAL)
94 		return 0;
95 
96 	if (model->state != MODEL_LOADED)
97 		return -EINVAL;
98 
99 	/* unload model */
100 	ret = rte_ml_model_unload(opt->dev_id, model->id);
101 	if (ret != 0) {
102 		ml_err("Failed to unload model: %s\n", opt->filelist[fid].model);
103 		model->state = MODEL_ERROR;
104 		return ret;
105 	}
106 
107 	model->state = MODEL_INITIAL;
108 
109 	return 0;
110 }
111 
112 int
113 ml_model_start(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid)
114 {
115 	struct test_common *t = ml_test_priv(test);
116 	int ret;
117 
118 	RTE_SET_USED(t);
119 
120 	if (model->state == MODEL_STARTED)
121 		return 0;
122 
123 	if (model->state != MODEL_LOADED)
124 		return -EINVAL;
125 
126 	/* start model */
127 	ret = rte_ml_model_start(opt->dev_id, model->id);
128 	if (ret != 0) {
129 		ml_err("Failed to start model : %s\n", opt->filelist[fid].model);
130 		model->state = MODEL_ERROR;
131 		return ret;
132 	}
133 
134 	model->state = MODEL_STARTED;
135 
136 	return 0;
137 }
138 
139 int
140 ml_model_stop(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid)
141 {
142 	struct test_common *t = ml_test_priv(test);
143 	int ret;
144 
145 	RTE_SET_USED(t);
146 
147 	if (model->state == MODEL_LOADED)
148 		return 0;
149 
150 	if (model->state != MODEL_STARTED)
151 		return -EINVAL;
152 
153 	/* stop model */
154 	ret = rte_ml_model_stop(opt->dev_id, model->id);
155 	if (ret != 0) {
156 		ml_err("Failed to stop model: %s\n", opt->filelist[fid].model);
157 		model->state = MODEL_ERROR;
158 		return ret;
159 	}
160 
161 	model->state = MODEL_LOADED;
162 
163 	return 0;
164 }
165