1 /* SPDX-License-Identifier: BSD-3-Clause 2 * Copyright(c) 2010-2016 Intel Corporation 3 */ 4 5 #include <sys/socket.h> 6 #include <sys/types.h> 7 #include <sys/stat.h> 8 #include <unistd.h> 9 #include <fcntl.h> 10 #include <sys/un.h> 11 #include <string.h> 12 #include <errno.h> 13 14 #include <rte_string_fns.h> 15 #include <rte_fbarray.h> 16 17 #include "vhost.h" 18 #include "virtio_user_dev.h" 19 20 /* The version of the protocol we support */ 21 #define VHOST_USER_VERSION 0x1 22 23 #define VHOST_MEMORY_MAX_NREGIONS 8 24 struct vhost_memory { 25 uint32_t nregions; 26 uint32_t padding; 27 struct vhost_memory_region regions[VHOST_MEMORY_MAX_NREGIONS]; 28 }; 29 30 struct vhost_user_msg { 31 enum vhost_user_request request; 32 33 #define VHOST_USER_VERSION_MASK 0x3 34 #define VHOST_USER_REPLY_MASK (0x1 << 2) 35 #define VHOST_USER_NEED_REPLY_MASK (0x1 << 3) 36 uint32_t flags; 37 uint32_t size; /* the following payload size */ 38 union { 39 #define VHOST_USER_VRING_IDX_MASK 0xff 40 #define VHOST_USER_VRING_NOFD_MASK (0x1 << 8) 41 uint64_t u64; 42 struct vhost_vring_state state; 43 struct vhost_vring_addr addr; 44 struct vhost_memory memory; 45 } payload; 46 int fds[VHOST_MEMORY_MAX_NREGIONS]; 47 } __rte_packed; 48 49 #define VHOST_USER_HDR_SIZE offsetof(struct vhost_user_msg, payload.u64) 50 #define VHOST_USER_PAYLOAD_SIZE \ 51 (sizeof(struct vhost_user_msg) - VHOST_USER_HDR_SIZE) 52 53 static int 54 vhost_user_write(int fd, struct vhost_user_msg *msg, int *fds, int fd_num) 55 { 56 int r; 57 struct msghdr msgh; 58 struct iovec iov; 59 size_t fd_size = fd_num * sizeof(int); 60 char control[CMSG_SPACE(fd_size)]; 61 struct cmsghdr *cmsg; 62 63 memset(&msgh, 0, sizeof(msgh)); 64 memset(control, 0, sizeof(control)); 65 66 iov.iov_base = (uint8_t *)msg; 67 iov.iov_len = VHOST_USER_HDR_SIZE + msg->size; 68 69 msgh.msg_iov = &iov; 70 msgh.msg_iovlen = 1; 71 msgh.msg_control = control; 72 msgh.msg_controllen = sizeof(control); 73 74 cmsg = CMSG_FIRSTHDR(&msgh); 75 cmsg->cmsg_len = CMSG_LEN(fd_size); 76 cmsg->cmsg_level = SOL_SOCKET; 77 cmsg->cmsg_type = SCM_RIGHTS; 78 memcpy(CMSG_DATA(cmsg), fds, fd_size); 79 80 do { 81 r = sendmsg(fd, &msgh, 0); 82 } while (r < 0 && errno == EINTR); 83 84 return r; 85 } 86 87 static int 88 vhost_user_read(int fd, struct vhost_user_msg *msg) 89 { 90 uint32_t valid_flags = VHOST_USER_REPLY_MASK | VHOST_USER_VERSION; 91 int ret, sz_hdr = VHOST_USER_HDR_SIZE, sz_payload; 92 93 ret = recv(fd, (void *)msg, sz_hdr, 0); 94 if (ret < sz_hdr) { 95 PMD_DRV_LOG(ERR, "Failed to recv msg hdr: %d instead of %d.", 96 ret, sz_hdr); 97 goto fail; 98 } 99 100 /* validate msg flags */ 101 if (msg->flags != (valid_flags)) { 102 PMD_DRV_LOG(ERR, "Failed to recv msg: flags %x instead of %x.", 103 msg->flags, valid_flags); 104 goto fail; 105 } 106 107 sz_payload = msg->size; 108 109 if ((size_t)sz_payload > sizeof(msg->payload)) 110 goto fail; 111 112 if (sz_payload) { 113 ret = recv(fd, (void *)((char *)msg + sz_hdr), sz_payload, 0); 114 if (ret < sz_payload) { 115 PMD_DRV_LOG(ERR, 116 "Failed to recv msg payload: %d instead of %d.", 117 ret, msg->size); 118 goto fail; 119 } 120 } 121 122 return 0; 123 124 fail: 125 return -1; 126 } 127 128 static int 129 vhost_user_set_owner(struct virtio_user_dev *dev) 130 { 131 int ret; 132 struct vhost_user_msg msg = { 133 .request = VHOST_USER_SET_OWNER, 134 .flags = VHOST_USER_VERSION, 135 }; 136 137 ret = vhost_user_write(dev->vhostfd, &msg, NULL, 0); 138 if (ret < 0) { 139 PMD_DRV_LOG(ERR, "Failed to set owner"); 140 return -1; 141 } 142 143 return 0; 144 } 145 146 static int 147 vhost_user_get_features(struct virtio_user_dev *dev, uint64_t *features) 148 { 149 int ret; 150 struct vhost_user_msg msg = { 151 .request = VHOST_USER_GET_FEATURES, 152 .flags = VHOST_USER_VERSION, 153 }; 154 155 ret = vhost_user_write(dev->vhostfd, &msg, NULL, 0); 156 if (ret < 0) 157 goto err; 158 159 ret = vhost_user_read(dev->vhostfd, &msg); 160 if (ret < 0) 161 goto err; 162 163 if (msg.request != VHOST_USER_GET_FEATURES) { 164 PMD_DRV_LOG(ERR, "Unexpected request type (%d)", msg.request); 165 goto err; 166 } 167 168 if (msg.size != sizeof(*features)) { 169 PMD_DRV_LOG(ERR, "Unexpected payload size (%u)", msg.size); 170 goto err; 171 } 172 173 *features = msg.payload.u64; 174 175 return 0; 176 err: 177 PMD_DRV_LOG(ERR, "Failed to get backend features"); 178 179 return -1; 180 } 181 182 static int 183 vhost_user_set_features(struct virtio_user_dev *dev, uint64_t features) 184 { 185 int ret; 186 struct vhost_user_msg msg = { 187 .request = VHOST_USER_SET_FEATURES, 188 .flags = VHOST_USER_VERSION, 189 .size = sizeof(features), 190 .payload.u64 = features, 191 }; 192 193 msg.payload.u64 |= dev->device_features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES); 194 195 ret = vhost_user_write(dev->vhostfd, &msg, NULL, 0); 196 if (ret < 0) { 197 PMD_DRV_LOG(ERR, "Failed to set features"); 198 return -1; 199 } 200 201 return 0; 202 } 203 204 static int 205 vhost_user_get_protocol_features(struct virtio_user_dev *dev, uint64_t *features) 206 { 207 int ret; 208 struct vhost_user_msg msg = { 209 .request = VHOST_USER_GET_PROTOCOL_FEATURES, 210 .flags = VHOST_USER_VERSION, 211 }; 212 213 ret = vhost_user_write(dev->vhostfd, &msg, NULL, 0); 214 if (ret < 0) 215 goto err; 216 217 ret = vhost_user_read(dev->vhostfd, &msg); 218 if (ret < 0) 219 goto err; 220 221 if (msg.request != VHOST_USER_GET_PROTOCOL_FEATURES) { 222 PMD_DRV_LOG(ERR, "Unexpected request type (%d)", msg.request); 223 goto err; 224 } 225 226 if (msg.size != sizeof(*features)) { 227 PMD_DRV_LOG(ERR, "Unexpected payload size (%u)", msg.size); 228 goto err; 229 } 230 231 *features = msg.payload.u64; 232 233 return 0; 234 err: 235 PMD_DRV_LOG(ERR, "Failed to get backend protocol features"); 236 237 return -1; 238 } 239 240 static int 241 vhost_user_set_protocol_features(struct virtio_user_dev *dev, uint64_t features) 242 { 243 int ret; 244 struct vhost_user_msg msg = { 245 .request = VHOST_USER_SET_PROTOCOL_FEATURES, 246 .flags = VHOST_USER_VERSION, 247 .size = sizeof(features), 248 .payload.u64 = features, 249 }; 250 251 ret = vhost_user_write(dev->vhostfd, &msg, NULL, 0); 252 if (ret < 0) { 253 PMD_DRV_LOG(ERR, "Failed to set protocol features"); 254 return -1; 255 } 256 257 return 0; 258 } 259 260 struct walk_arg { 261 struct vhost_memory *vm; 262 int *fds; 263 int region_nr; 264 }; 265 266 static int 267 update_memory_region(const struct rte_memseg_list *msl __rte_unused, 268 const struct rte_memseg *ms, void *arg) 269 { 270 struct walk_arg *wa = arg; 271 struct vhost_memory_region *mr; 272 uint64_t start_addr, end_addr; 273 size_t offset; 274 int i, fd; 275 276 fd = rte_memseg_get_fd_thread_unsafe(ms); 277 if (fd < 0) { 278 PMD_DRV_LOG(ERR, "Failed to get fd, ms=%p rte_errno=%d", 279 ms, rte_errno); 280 return -1; 281 } 282 283 if (rte_memseg_get_fd_offset_thread_unsafe(ms, &offset) < 0) { 284 PMD_DRV_LOG(ERR, "Failed to get offset, ms=%p rte_errno=%d", 285 ms, rte_errno); 286 return -1; 287 } 288 289 start_addr = (uint64_t)(uintptr_t)ms->addr; 290 end_addr = start_addr + ms->len; 291 292 for (i = 0; i < wa->region_nr; i++) { 293 if (wa->fds[i] != fd) 294 continue; 295 296 mr = &wa->vm->regions[i]; 297 298 if (mr->userspace_addr + mr->memory_size < end_addr) 299 mr->memory_size = end_addr - mr->userspace_addr; 300 301 if (mr->userspace_addr > start_addr) { 302 mr->userspace_addr = start_addr; 303 mr->guest_phys_addr = start_addr; 304 } 305 306 if (mr->mmap_offset > offset) 307 mr->mmap_offset = offset; 308 309 PMD_DRV_LOG(DEBUG, "index=%d fd=%d offset=0x%" PRIx64 310 " addr=0x%" PRIx64 " len=%" PRIu64, i, fd, 311 mr->mmap_offset, mr->userspace_addr, 312 mr->memory_size); 313 314 return 0; 315 } 316 317 if (i >= VHOST_MEMORY_MAX_NREGIONS) { 318 PMD_DRV_LOG(ERR, "Too many memory regions"); 319 return -1; 320 } 321 322 mr = &wa->vm->regions[i]; 323 wa->fds[i] = fd; 324 325 mr->guest_phys_addr = start_addr; 326 mr->userspace_addr = start_addr; 327 mr->memory_size = ms->len; 328 mr->mmap_offset = offset; 329 330 PMD_DRV_LOG(DEBUG, "index=%d fd=%d offset=0x%" PRIx64 331 " addr=0x%" PRIx64 " len=%" PRIu64, i, fd, 332 mr->mmap_offset, mr->userspace_addr, 333 mr->memory_size); 334 335 wa->region_nr++; 336 337 return 0; 338 } 339 340 static int 341 prepare_vhost_memory_user(struct vhost_user_msg *msg, int fds[]) 342 { 343 struct walk_arg wa; 344 345 wa.region_nr = 0; 346 wa.vm = &msg->payload.memory; 347 wa.fds = fds; 348 349 /* 350 * The memory lock has already been taken by memory subsystem 351 * or virtio_user_start_device(). 352 */ 353 if (rte_memseg_walk_thread_unsafe(update_memory_region, &wa) < 0) 354 return -1; 355 356 msg->payload.memory.nregions = wa.region_nr; 357 msg->payload.memory.padding = 0; 358 359 return 0; 360 } 361 362 static struct vhost_user_msg m; 363 364 const char * const vhost_msg_strings[] = { 365 [VHOST_USER_RESET_OWNER] = "VHOST_RESET_OWNER", 366 [VHOST_USER_SET_VRING_CALL] = "VHOST_SET_VRING_CALL", 367 [VHOST_USER_SET_VRING_NUM] = "VHOST_SET_VRING_NUM", 368 [VHOST_USER_SET_VRING_BASE] = "VHOST_SET_VRING_BASE", 369 [VHOST_USER_GET_VRING_BASE] = "VHOST_GET_VRING_BASE", 370 [VHOST_USER_SET_VRING_ADDR] = "VHOST_SET_VRING_ADDR", 371 [VHOST_USER_SET_VRING_KICK] = "VHOST_SET_VRING_KICK", 372 [VHOST_USER_SET_MEM_TABLE] = "VHOST_SET_MEM_TABLE", 373 [VHOST_USER_SET_VRING_ENABLE] = "VHOST_SET_VRING_ENABLE", 374 [VHOST_USER_SET_STATUS] = "VHOST_SET_STATUS", 375 [VHOST_USER_GET_STATUS] = "VHOST_GET_STATUS", 376 }; 377 378 static int 379 vhost_user_sock(struct virtio_user_dev *dev, 380 enum vhost_user_request req, 381 void *arg) 382 { 383 struct vhost_user_msg msg; 384 struct vhost_vring_file *file = 0; 385 int need_reply = 0; 386 int has_reply_ack = 0; 387 int fds[VHOST_MEMORY_MAX_NREGIONS]; 388 int fd_num = 0; 389 int vhostfd = dev->vhostfd; 390 391 RTE_SET_USED(m); 392 393 PMD_DRV_LOG(INFO, "%s", vhost_msg_strings[req]); 394 395 if (dev->is_server && vhostfd < 0) 396 return -1; 397 398 if (dev->protocol_features & (1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK)) 399 has_reply_ack = 1; 400 401 msg.request = req; 402 msg.flags = VHOST_USER_VERSION; 403 msg.size = 0; 404 405 switch (req) { 406 case VHOST_USER_GET_STATUS: 407 if (!(dev->status & VIRTIO_CONFIG_STATUS_FEATURES_OK) || 408 (!(dev->protocol_features & 409 (1ULL << VHOST_USER_PROTOCOL_F_STATUS)))) 410 return -ENOTSUP; 411 need_reply = 1; 412 break; 413 414 case VHOST_USER_SET_STATUS: 415 if (!(dev->status & VIRTIO_CONFIG_STATUS_FEATURES_OK) || 416 (!(dev->protocol_features & 417 (1ULL << VHOST_USER_PROTOCOL_F_STATUS)))) 418 return -ENOTSUP; 419 420 if (has_reply_ack) 421 msg.flags |= VHOST_USER_NEED_REPLY_MASK; 422 /* Fallthrough */ 423 case VHOST_USER_SET_LOG_BASE: 424 msg.payload.u64 = *((__u64 *)arg); 425 msg.size = sizeof(m.payload.u64); 426 break; 427 428 case VHOST_USER_SET_FEATURES: 429 msg.payload.u64 = *((__u64 *)arg) | (dev->device_features & 430 (1ULL << VHOST_USER_F_PROTOCOL_FEATURES)); 431 msg.size = sizeof(m.payload.u64); 432 break; 433 434 case VHOST_USER_RESET_OWNER: 435 break; 436 437 case VHOST_USER_SET_MEM_TABLE: 438 if (prepare_vhost_memory_user(&msg, fds) < 0) 439 return -1; 440 fd_num = msg.payload.memory.nregions; 441 msg.size = sizeof(m.payload.memory.nregions); 442 msg.size += sizeof(m.payload.memory.padding); 443 msg.size += fd_num * sizeof(struct vhost_memory_region); 444 445 if (has_reply_ack) 446 msg.flags |= VHOST_USER_NEED_REPLY_MASK; 447 break; 448 449 case VHOST_USER_SET_LOG_FD: 450 fds[fd_num++] = *((int *)arg); 451 break; 452 453 case VHOST_USER_SET_VRING_NUM: 454 case VHOST_USER_SET_VRING_BASE: 455 case VHOST_USER_SET_VRING_ENABLE: 456 memcpy(&msg.payload.state, arg, sizeof(msg.payload.state)); 457 msg.size = sizeof(m.payload.state); 458 break; 459 460 case VHOST_USER_GET_VRING_BASE: 461 memcpy(&msg.payload.state, arg, sizeof(msg.payload.state)); 462 msg.size = sizeof(m.payload.state); 463 need_reply = 1; 464 break; 465 466 case VHOST_USER_SET_VRING_ADDR: 467 memcpy(&msg.payload.addr, arg, sizeof(msg.payload.addr)); 468 msg.size = sizeof(m.payload.addr); 469 break; 470 471 case VHOST_USER_SET_VRING_KICK: 472 case VHOST_USER_SET_VRING_CALL: 473 case VHOST_USER_SET_VRING_ERR: 474 file = arg; 475 msg.payload.u64 = file->index & VHOST_USER_VRING_IDX_MASK; 476 msg.size = sizeof(m.payload.u64); 477 if (file->fd > 0) 478 fds[fd_num++] = file->fd; 479 else 480 msg.payload.u64 |= VHOST_USER_VRING_NOFD_MASK; 481 break; 482 483 default: 484 PMD_DRV_LOG(ERR, "trying to send unhandled msg type"); 485 return -1; 486 } 487 488 if (vhost_user_write(vhostfd, &msg, fds, fd_num) < 0) { 489 PMD_DRV_LOG(ERR, "%s failed: %s", 490 vhost_msg_strings[req], strerror(errno)); 491 return -1; 492 } 493 494 if (need_reply || msg.flags & VHOST_USER_NEED_REPLY_MASK) { 495 if (vhost_user_read(vhostfd, &msg) < 0) { 496 PMD_DRV_LOG(ERR, "Received msg failed: %s", 497 strerror(errno)); 498 return -1; 499 } 500 501 if (req != msg.request) { 502 PMD_DRV_LOG(ERR, "Received unexpected msg type"); 503 return -1; 504 } 505 506 switch (req) { 507 case VHOST_USER_GET_STATUS: 508 if (msg.size != sizeof(m.payload.u64)) { 509 PMD_DRV_LOG(ERR, "Received bad msg size"); 510 return -1; 511 } 512 *((__u64 *)arg) = msg.payload.u64; 513 break; 514 case VHOST_USER_GET_VRING_BASE: 515 if (msg.size != sizeof(m.payload.state)) { 516 PMD_DRV_LOG(ERR, "Received bad msg size"); 517 return -1; 518 } 519 memcpy(arg, &msg.payload.state, 520 sizeof(struct vhost_vring_state)); 521 break; 522 default: 523 /* Reply-ack handling */ 524 if (msg.size != sizeof(m.payload.u64)) { 525 PMD_DRV_LOG(ERR, "Received bad msg size"); 526 return -1; 527 } 528 529 if (msg.payload.u64 != 0) { 530 PMD_DRV_LOG(ERR, "Slave replied NACK"); 531 return -1; 532 } 533 534 break; 535 } 536 } 537 538 return 0; 539 } 540 541 #define MAX_VIRTIO_USER_BACKLOG 1 542 static int 543 virtio_user_start_server(struct virtio_user_dev *dev, struct sockaddr_un *un) 544 { 545 int ret; 546 int flag; 547 int fd = dev->listenfd; 548 549 ret = bind(fd, (struct sockaddr *)un, sizeof(*un)); 550 if (ret < 0) { 551 PMD_DRV_LOG(ERR, "failed to bind to %s: %s; remove it and try again\n", 552 dev->path, strerror(errno)); 553 return -1; 554 } 555 ret = listen(fd, MAX_VIRTIO_USER_BACKLOG); 556 if (ret < 0) 557 return -1; 558 559 flag = fcntl(fd, F_GETFL); 560 if (fcntl(fd, F_SETFL, flag | O_NONBLOCK) < 0) { 561 PMD_DRV_LOG(ERR, "fcntl failed, %s", strerror(errno)); 562 return -1; 563 } 564 565 return 0; 566 } 567 568 /** 569 * Set up environment to talk with a vhost user backend. 570 * 571 * @return 572 * - (-1) if fail; 573 * - (0) if succeed. 574 */ 575 static int 576 vhost_user_setup(struct virtio_user_dev *dev) 577 { 578 int fd; 579 int flag; 580 struct sockaddr_un un; 581 582 fd = socket(AF_UNIX, SOCK_STREAM, 0); 583 if (fd < 0) { 584 PMD_DRV_LOG(ERR, "socket() error, %s", strerror(errno)); 585 return -1; 586 } 587 588 flag = fcntl(fd, F_GETFD); 589 if (fcntl(fd, F_SETFD, flag | FD_CLOEXEC) < 0) 590 PMD_DRV_LOG(WARNING, "fcntl failed, %s", strerror(errno)); 591 592 memset(&un, 0, sizeof(un)); 593 un.sun_family = AF_UNIX; 594 strlcpy(un.sun_path, dev->path, sizeof(un.sun_path)); 595 596 if (dev->is_server) { 597 dev->listenfd = fd; 598 if (virtio_user_start_server(dev, &un) < 0) { 599 PMD_DRV_LOG(ERR, "virtio-user startup fails in server mode"); 600 close(fd); 601 return -1; 602 } 603 dev->vhostfd = -1; 604 } else { 605 if (connect(fd, (struct sockaddr *)&un, sizeof(un)) < 0) { 606 PMD_DRV_LOG(ERR, "connect error, %s", strerror(errno)); 607 close(fd); 608 return -1; 609 } 610 dev->vhostfd = fd; 611 } 612 613 return 0; 614 } 615 616 static int 617 vhost_user_enable_queue_pair(struct virtio_user_dev *dev, 618 uint16_t pair_idx, 619 int enable) 620 { 621 int i; 622 623 if (dev->qp_enabled[pair_idx] == enable) 624 return 0; 625 626 for (i = 0; i < 2; ++i) { 627 struct vhost_vring_state state = { 628 .index = pair_idx * 2 + i, 629 .num = enable, 630 }; 631 632 if (vhost_user_sock(dev, VHOST_USER_SET_VRING_ENABLE, &state)) 633 return -1; 634 } 635 636 dev->qp_enabled[pair_idx] = enable; 637 return 0; 638 } 639 640 struct virtio_user_backend_ops virtio_ops_user = { 641 .setup = vhost_user_setup, 642 .set_owner = vhost_user_set_owner, 643 .get_features = vhost_user_get_features, 644 .set_features = vhost_user_set_features, 645 .get_protocol_features = vhost_user_get_protocol_features, 646 .set_protocol_features = vhost_user_set_protocol_features, 647 .send_request = vhost_user_sock, 648 .enable_qp = vhost_user_enable_queue_pair 649 }; 650