/* SPDX-License-Identifier: BSD-3-Clause
 * Copyright(c) 2010-2016 Intel Corporation
 */

#include <sys/socket.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/un.h>
#include <string.h>
#include <errno.h>

#include <rte_string_fns.h>
#include <rte_fbarray.h>

#include "vhost.h"
#include "virtio_user_dev.h"

/* The version of the protocol we support */
#define VHOST_USER_VERSION 0x1

#define VHOST_MEMORY_MAX_NREGIONS 8
struct vhost_memory {
	uint32_t nregions;
	uint32_t padding;
	struct vhost_memory_region regions[VHOST_MEMORY_MAX_NREGIONS];
};

struct vhost_user_msg {
	enum vhost_user_request request;

#define VHOST_USER_VERSION_MASK 0x3
#define VHOST_USER_REPLY_MASK (0x1 << 2)
	uint32_t flags;
	uint32_t size; /* the following payload size */
	union {
#define VHOST_USER_VRING_IDX_MASK 0xff
#define VHOST_USER_VRING_NOFD_MASK (0x1 << 8)
		uint64_t u64;
		struct vhost_vring_state state;
		struct vhost_vring_addr addr;
		struct vhost_memory memory;
	} payload;
	int fds[VHOST_MEMORY_MAX_NREGIONS];
} __attribute__((packed));

#define VHOST_USER_HDR_SIZE offsetof(struct vhost_user_msg, payload.u64)
#define VHOST_USER_PAYLOAD_SIZE \
	(sizeof(struct vhost_user_msg) - VHOST_USER_HDR_SIZE)

static int
vhost_user_write(int fd, void *buf, int len, int *fds, int fd_num)
{
	int r;
	struct msghdr msgh;
	struct iovec iov;
	size_t fd_size = fd_num * sizeof(int);
	char control[CMSG_SPACE(fd_size)];
	struct cmsghdr *cmsg;

	memset(&msgh, 0, sizeof(msgh));
	memset(control, 0, sizeof(control));

	iov.iov_base = (uint8_t *)buf;
	iov.iov_len = len;

	msgh.msg_iov = &iov;
	msgh.msg_iovlen = 1;
	msgh.msg_control = control;
	msgh.msg_controllen = sizeof(control);

	cmsg = CMSG_FIRSTHDR(&msgh);
	cmsg->cmsg_len = CMSG_LEN(fd_size);
	cmsg->cmsg_level = SOL_SOCKET;
	cmsg->cmsg_type = SCM_RIGHTS;
	memcpy(CMSG_DATA(cmsg), fds, fd_size);

	do {
		r = sendmsg(fd, &msgh, 0);
	} while (r < 0 && errno == EINTR);

	return r;
}

static int
vhost_user_read(int fd, struct vhost_user_msg *msg)
{
	uint32_t valid_flags = VHOST_USER_REPLY_MASK | VHOST_USER_VERSION;
	int ret, sz_hdr = VHOST_USER_HDR_SIZE, sz_payload;

	ret = recv(fd, (void *)msg, sz_hdr, 0);
	if (ret < sz_hdr) {
		PMD_DRV_LOG(ERR, "Failed to recv msg hdr: %d instead of %d.",
			    ret, sz_hdr);
		goto fail;
	}

	/* validate msg flags */
	if (msg->flags != valid_flags) {
		PMD_DRV_LOG(ERR, "Failed to recv msg: flags %x instead of %x.",
			    msg->flags, valid_flags);
		goto fail;
	}

	sz_payload = msg->size;

	if ((size_t)sz_payload > sizeof(msg->payload))
		goto fail;

	if (sz_payload) {
		ret = recv(fd, (void *)((char *)msg + sz_hdr), sz_payload, 0);
		if (ret < sz_payload) {
			PMD_DRV_LOG(ERR,
				"Failed to recv msg payload: %d instead of %d.",
				ret, msg->size);
			goto fail;
		}
	}

	return 0;

fail:
	return -1;
}

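/*
 * Callback for rte_memseg_walk_thread_unsafe(): builds the vhost memory
 * table handed to the backend. Memsegs backed by an already-seen fd are
 * merged into that region by extending its size and lowering its start
 * address and mmap offset as needed; a new region is only allocated for
 * a previously unseen fd, up to VHOST_MEMORY_MAX_NREGIONS entries.
 */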
struct walk_arg {
	struct vhost_memory *vm;
	int *fds;
	int region_nr;
};

static int
update_memory_region(const struct rte_memseg_list *msl __rte_unused,
		const struct rte_memseg *ms, void *arg)
{
	struct walk_arg *wa = arg;
	struct vhost_memory_region *mr;
	uint64_t start_addr, end_addr;
	size_t offset;
	int i, fd;

	fd = rte_memseg_get_fd_thread_unsafe(ms);
	if (fd < 0) {
		PMD_DRV_LOG(ERR, "Failed to get fd, ms=%p rte_errno=%d",
			    ms, rte_errno);
		return -1;
	}

	if (rte_memseg_get_fd_offset_thread_unsafe(ms, &offset) < 0) {
		PMD_DRV_LOG(ERR, "Failed to get offset, ms=%p rte_errno=%d",
			    ms, rte_errno);
		return -1;
	}

	start_addr = (uint64_t)(uintptr_t)ms->addr;
	end_addr = start_addr + ms->len;

	for (i = 0; i < wa->region_nr; i++) {
		if (wa->fds[i] != fd)
			continue;

		mr = &wa->vm->regions[i];

		if (mr->userspace_addr + mr->memory_size < end_addr)
			mr->memory_size = end_addr - mr->userspace_addr;

		if (mr->userspace_addr > start_addr) {
			mr->userspace_addr = start_addr;
			mr->guest_phys_addr = start_addr;
		}

		if (mr->mmap_offset > offset)
			mr->mmap_offset = offset;

		PMD_DRV_LOG(DEBUG, "index=%d fd=%d offset=0x%" PRIx64
			" addr=0x%" PRIx64 " len=%" PRIu64, i, fd,
			mr->mmap_offset, mr->userspace_addr,
			mr->memory_size);

		return 0;
	}

	if (i >= VHOST_MEMORY_MAX_NREGIONS) {
		PMD_DRV_LOG(ERR, "Too many memory regions");
		return -1;
	}

	mr = &wa->vm->regions[i];
	wa->fds[i] = fd;

	mr->guest_phys_addr = start_addr;
	mr->userspace_addr = start_addr;
	mr->memory_size = ms->len;
	mr->mmap_offset = offset;

	PMD_DRV_LOG(DEBUG, "index=%d fd=%d offset=0x%" PRIx64
		" addr=0x%" PRIx64 " len=%" PRIu64, i, fd,
		mr->mmap_offset, mr->userspace_addr,
		mr->memory_size);

	wa->region_nr++;

	return 0;
}

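/*
 * Fill in the VHOST_USER_SET_MEM_TABLE payload: walk all memsegs to
 * collect the region descriptors in msg->payload.memory and the matching
 * backing fds in fds[], which are later passed over the socket as
 * SCM_RIGHTS ancillary data by vhost_user_write().
 */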
static int
prepare_vhost_memory_user(struct vhost_user_msg *msg, int fds[])
{
	struct walk_arg wa;

	wa.region_nr = 0;
	wa.vm = &msg->payload.memory;
	wa.fds = fds;

	/*
	 * The memory lock has already been taken by memory subsystem
	 * or virtio_user_start_device().
	 */
	if (rte_memseg_walk_thread_unsafe(update_memory_region, &wa) < 0)
		return -1;

	msg->payload.memory.nregions = wa.region_nr;
	msg->payload.memory.padding = 0;

	return 0;
}

static struct vhost_user_msg m;

const char * const vhost_msg_strings[] = {
	[VHOST_USER_SET_OWNER] = "VHOST_SET_OWNER",
	[VHOST_USER_RESET_OWNER] = "VHOST_RESET_OWNER",
	[VHOST_USER_SET_FEATURES] = "VHOST_SET_FEATURES",
	[VHOST_USER_GET_FEATURES] = "VHOST_GET_FEATURES",
	[VHOST_USER_SET_VRING_CALL] = "VHOST_SET_VRING_CALL",
	[VHOST_USER_SET_VRING_NUM] = "VHOST_SET_VRING_NUM",
	[VHOST_USER_SET_VRING_BASE] = "VHOST_SET_VRING_BASE",
	[VHOST_USER_GET_VRING_BASE] = "VHOST_GET_VRING_BASE",
	[VHOST_USER_SET_VRING_ADDR] = "VHOST_SET_VRING_ADDR",
	[VHOST_USER_SET_VRING_KICK] = "VHOST_SET_VRING_KICK",
	[VHOST_USER_SET_MEM_TABLE] = "VHOST_SET_MEM_TABLE",
	[VHOST_USER_SET_VRING_ENABLE] = "VHOST_SET_VRING_ENABLE",
};

static int
vhost_user_sock(struct virtio_user_dev *dev,
		enum vhost_user_request req,
		void *arg)
{
	struct vhost_user_msg msg;
	struct vhost_vring_file *file = NULL;
	int need_reply = 0;
	int fds[VHOST_MEMORY_MAX_NREGIONS];
	int fd_num = 0;
	int len;
	int vhostfd = dev->vhostfd;

	RTE_SET_USED(m);

	PMD_DRV_LOG(INFO, "%s", vhost_msg_strings[req]);

	if (dev->is_server && vhostfd < 0)
		return -1;

	msg.request = req;
	msg.flags = VHOST_USER_VERSION;
	msg.size = 0;

	switch (req) {
	case VHOST_USER_GET_FEATURES:
		need_reply = 1;
		break;

	case VHOST_USER_SET_FEATURES:
	case VHOST_USER_SET_LOG_BASE:
		msg.payload.u64 = *((__u64 *)arg);
		msg.size = sizeof(m.payload.u64);
		break;

	case VHOST_USER_SET_OWNER:
	case VHOST_USER_RESET_OWNER:
		break;

	case VHOST_USER_SET_MEM_TABLE:
		if (prepare_vhost_memory_user(&msg, fds) < 0)
			return -1;
		fd_num = msg.payload.memory.nregions;
		msg.size = sizeof(m.payload.memory.nregions);
		msg.size += sizeof(m.payload.memory.padding);
		msg.size += fd_num * sizeof(struct vhost_memory_region);
		break;

	case VHOST_USER_SET_LOG_FD:
		fds[fd_num++] = *((int *)arg);
		break;

	case VHOST_USER_SET_VRING_NUM:
	case VHOST_USER_SET_VRING_BASE:
	case VHOST_USER_SET_VRING_ENABLE:
		memcpy(&msg.payload.state, arg, sizeof(msg.payload.state));
		msg.size = sizeof(m.payload.state);
		break;

	case VHOST_USER_GET_VRING_BASE:
		memcpy(&msg.payload.state, arg, sizeof(msg.payload.state));
		msg.size = sizeof(m.payload.state);
		need_reply = 1;
		break;

	case VHOST_USER_SET_VRING_ADDR:
		memcpy(&msg.payload.addr, arg, sizeof(msg.payload.addr));
		msg.size = sizeof(m.payload.addr);
		break;

	case VHOST_USER_SET_VRING_KICK:
	case VHOST_USER_SET_VRING_CALL:
	case VHOST_USER_SET_VRING_ERR:
		file = arg;
		msg.payload.u64 = file->index & VHOST_USER_VRING_IDX_MASK;
		msg.size = sizeof(m.payload.u64);
		if (file->fd > 0)
			fds[fd_num++] = file->fd;
		else
			msg.payload.u64 |= VHOST_USER_VRING_NOFD_MASK;
		break;

	default:
		PMD_DRV_LOG(ERR, "trying to send unhandled msg type");
		return -1;
	}

	len = VHOST_USER_HDR_SIZE + msg.size;
	if (vhost_user_write(vhostfd, &msg, len, fds, fd_num) < 0) {
		PMD_DRV_LOG(ERR, "%s failed: %s",
			    vhost_msg_strings[req], strerror(errno));
		return -1;
	}

	if (need_reply) {
		if (vhost_user_read(vhostfd, &msg) < 0) {
			PMD_DRV_LOG(ERR, "Failed to recv reply msg: %s",
				    strerror(errno));
			return -1;
		}

		if (req != msg.request) {
			PMD_DRV_LOG(ERR, "Received unexpected msg type");
			return -1;
		}

		switch (req) {
		case VHOST_USER_GET_FEATURES:
			if (msg.size != sizeof(m.payload.u64)) {
				PMD_DRV_LOG(ERR, "Received bad msg size");
				return -1;
			}
			*((__u64 *)arg) = msg.payload.u64;
			break;
		case VHOST_USER_GET_VRING_BASE:
			if (msg.size != sizeof(m.payload.state)) {
				PMD_DRV_LOG(ERR, "Received bad msg size");
				return -1;
			}
			memcpy(arg, &msg.payload.state,
			       sizeof(struct vhost_vring_state));
			break;
		default:
			PMD_DRV_LOG(ERR, "Received unexpected msg type");
			return -1;
		}
	}

	return 0;
}

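/*
 * Server mode: bind and listen on the unix socket given by dev->path and
 * make the listen fd non-blocking. The connected vhost fd is only set up
 * once a backend connects, so vhost_user_setup() leaves dev->vhostfd at -1
 * in this mode.
 */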
PMD_DRV_LOG(ERR, "Received msg failed: %s", 343 strerror(errno)); 344 return -1; 345 } 346 347 if (req != msg.request) { 348 PMD_DRV_LOG(ERR, "Received unexpected msg type"); 349 return -1; 350 } 351 352 switch (req) { 353 case VHOST_USER_GET_FEATURES: 354 if (msg.size != sizeof(m.payload.u64)) { 355 PMD_DRV_LOG(ERR, "Received bad msg size"); 356 return -1; 357 } 358 *((__u64 *)arg) = msg.payload.u64; 359 break; 360 case VHOST_USER_GET_VRING_BASE: 361 if (msg.size != sizeof(m.payload.state)) { 362 PMD_DRV_LOG(ERR, "Received bad msg size"); 363 return -1; 364 } 365 memcpy(arg, &msg.payload.state, 366 sizeof(struct vhost_vring_state)); 367 break; 368 default: 369 PMD_DRV_LOG(ERR, "Received unexpected msg type"); 370 return -1; 371 } 372 } 373 374 return 0; 375 } 376 377 #define MAX_VIRTIO_USER_BACKLOG 1 378 static int 379 virtio_user_start_server(struct virtio_user_dev *dev, struct sockaddr_un *un) 380 { 381 int ret; 382 int flag; 383 int fd = dev->listenfd; 384 385 ret = bind(fd, (struct sockaddr *)un, sizeof(*un)); 386 if (ret < 0) { 387 PMD_DRV_LOG(ERR, "failed to bind to %s: %s; remove it and try again\n", 388 dev->path, strerror(errno)); 389 return -1; 390 } 391 ret = listen(fd, MAX_VIRTIO_USER_BACKLOG); 392 if (ret < 0) 393 return -1; 394 395 flag = fcntl(fd, F_GETFL); 396 if (fcntl(fd, F_SETFL, flag | O_NONBLOCK) < 0) { 397 PMD_DRV_LOG(ERR, "fcntl failed, %s", strerror(errno)); 398 return -1; 399 } 400 401 return 0; 402 } 403 404 /** 405 * Set up environment to talk with a vhost user backend. 406 * 407 * @return 408 * - (-1) if fail; 409 * - (0) if succeed. 410 */ 411 static int 412 vhost_user_setup(struct virtio_user_dev *dev) 413 { 414 int fd; 415 int flag; 416 struct sockaddr_un un; 417 418 fd = socket(AF_UNIX, SOCK_STREAM, 0); 419 if (fd < 0) { 420 PMD_DRV_LOG(ERR, "socket() error, %s", strerror(errno)); 421 return -1; 422 } 423 424 flag = fcntl(fd, F_GETFD); 425 if (fcntl(fd, F_SETFD, flag | FD_CLOEXEC) < 0) 426 PMD_DRV_LOG(WARNING, "fcntl failed, %s", strerror(errno)); 427 428 memset(&un, 0, sizeof(un)); 429 un.sun_family = AF_UNIX; 430 strlcpy(un.sun_path, dev->path, sizeof(un.sun_path)); 431 432 if (dev->is_server) { 433 dev->listenfd = fd; 434 if (virtio_user_start_server(dev, &un) < 0) { 435 PMD_DRV_LOG(ERR, "virtio-user startup fails in server mode"); 436 close(fd); 437 return -1; 438 } 439 dev->vhostfd = -1; 440 } else { 441 if (connect(fd, (struct sockaddr *)&un, sizeof(un)) < 0) { 442 PMD_DRV_LOG(ERR, "connect error, %s", strerror(errno)); 443 close(fd); 444 return -1; 445 } 446 dev->vhostfd = fd; 447 } 448 449 return 0; 450 } 451 452 static int 453 vhost_user_enable_queue_pair(struct virtio_user_dev *dev, 454 uint16_t pair_idx, 455 int enable) 456 { 457 int i; 458 459 for (i = 0; i < 2; ++i) { 460 struct vhost_vring_state state = { 461 .index = pair_idx * 2 + i, 462 .num = enable, 463 }; 464 465 if (vhost_user_sock(dev, VHOST_USER_SET_VRING_ENABLE, &state)) 466 return -1; 467 } 468 469 return 0; 470 } 471 472 struct virtio_user_backend_ops virtio_ops_user = { 473 .setup = vhost_user_setup, 474 .send_request = vhost_user_sock, 475 .enable_qp = vhost_user_enable_queue_pair 476 }; 477