1 /* SPDX-License-Identifier: BSD-3-Clause 2 * Copyright(c) 2010-2016 Intel Corporation 3 */ 4 5 #include <sys/socket.h> 6 #include <sys/types.h> 7 #include <sys/stat.h> 8 #include <unistd.h> 9 #include <fcntl.h> 10 #include <sys/un.h> 11 #include <string.h> 12 #include <errno.h> 13 14 #include <rte_string_fns.h> 15 #include <rte_fbarray.h> 16 17 #include "vhost.h" 18 #include "virtio_user_dev.h" 19 20 /* The version of the protocol we support */ 21 #define VHOST_USER_VERSION 0x1 22 23 #define VHOST_MEMORY_MAX_NREGIONS 8 24 struct vhost_memory { 25 uint32_t nregions; 26 uint32_t padding; 27 struct vhost_memory_region regions[VHOST_MEMORY_MAX_NREGIONS]; 28 }; 29 30 struct vhost_user_msg { 31 enum vhost_user_request request; 32 33 #define VHOST_USER_VERSION_MASK 0x3 34 #define VHOST_USER_REPLY_MASK (0x1 << 2) 35 #define VHOST_USER_NEED_REPLY_MASK (0x1 << 3) 36 uint32_t flags; 37 uint32_t size; /* the following payload size */ 38 union { 39 #define VHOST_USER_VRING_IDX_MASK 0xff 40 #define VHOST_USER_VRING_NOFD_MASK (0x1 << 8) 41 uint64_t u64; 42 struct vhost_vring_state state; 43 struct vhost_vring_addr addr; 44 struct vhost_memory memory; 45 } payload; 46 int fds[VHOST_MEMORY_MAX_NREGIONS]; 47 } __rte_packed; 48 49 #define VHOST_USER_HDR_SIZE offsetof(struct vhost_user_msg, payload.u64) 50 #define VHOST_USER_PAYLOAD_SIZE \ 51 (sizeof(struct vhost_user_msg) - VHOST_USER_HDR_SIZE) 52 53 static int 54 vhost_user_write(int fd, void *buf, int len, int *fds, int fd_num) 55 { 56 int r; 57 struct msghdr msgh; 58 struct iovec iov; 59 size_t fd_size = fd_num * sizeof(int); 60 char control[CMSG_SPACE(fd_size)]; 61 struct cmsghdr *cmsg; 62 63 memset(&msgh, 0, sizeof(msgh)); 64 memset(control, 0, sizeof(control)); 65 66 iov.iov_base = (uint8_t *)buf; 67 iov.iov_len = len; 68 69 msgh.msg_iov = &iov; 70 msgh.msg_iovlen = 1; 71 msgh.msg_control = control; 72 msgh.msg_controllen = sizeof(control); 73 74 cmsg = CMSG_FIRSTHDR(&msgh); 75 cmsg->cmsg_len = CMSG_LEN(fd_size); 76 cmsg->cmsg_level = SOL_SOCKET; 77 cmsg->cmsg_type = SCM_RIGHTS; 78 memcpy(CMSG_DATA(cmsg), fds, fd_size); 79 80 do { 81 r = sendmsg(fd, &msgh, 0); 82 } while (r < 0 && errno == EINTR); 83 84 return r; 85 } 86 87 static int 88 vhost_user_read(int fd, struct vhost_user_msg *msg) 89 { 90 uint32_t valid_flags = VHOST_USER_REPLY_MASK | VHOST_USER_VERSION; 91 int ret, sz_hdr = VHOST_USER_HDR_SIZE, sz_payload; 92 93 ret = recv(fd, (void *)msg, sz_hdr, 0); 94 if (ret < sz_hdr) { 95 PMD_DRV_LOG(ERR, "Failed to recv msg hdr: %d instead of %d.", 96 ret, sz_hdr); 97 goto fail; 98 } 99 100 /* validate msg flags */ 101 if (msg->flags != (valid_flags)) { 102 PMD_DRV_LOG(ERR, "Failed to recv msg: flags %x instead of %x.", 103 msg->flags, valid_flags); 104 goto fail; 105 } 106 107 sz_payload = msg->size; 108 109 if ((size_t)sz_payload > sizeof(msg->payload)) 110 goto fail; 111 112 if (sz_payload) { 113 ret = recv(fd, (void *)((char *)msg + sz_hdr), sz_payload, 0); 114 if (ret < sz_payload) { 115 PMD_DRV_LOG(ERR, 116 "Failed to recv msg payload: %d instead of %d.", 117 ret, msg->size); 118 goto fail; 119 } 120 } 121 122 return 0; 123 124 fail: 125 return -1; 126 } 127 128 struct walk_arg { 129 struct vhost_memory *vm; 130 int *fds; 131 int region_nr; 132 }; 133 134 static int 135 update_memory_region(const struct rte_memseg_list *msl __rte_unused, 136 const struct rte_memseg *ms, void *arg) 137 { 138 struct walk_arg *wa = arg; 139 struct vhost_memory_region *mr; 140 uint64_t start_addr, end_addr; 141 size_t offset; 142 int i, fd; 143 144 fd = rte_memseg_get_fd_thread_unsafe(ms); 145 if (fd < 0) { 146 PMD_DRV_LOG(ERR, "Failed to get fd, ms=%p rte_errno=%d", 147 ms, rte_errno); 148 return -1; 149 } 150 151 if (rte_memseg_get_fd_offset_thread_unsafe(ms, &offset) < 0) { 152 PMD_DRV_LOG(ERR, "Failed to get offset, ms=%p rte_errno=%d", 153 ms, rte_errno); 154 return -1; 155 } 156 157 start_addr = (uint64_t)(uintptr_t)ms->addr; 158 end_addr = start_addr + ms->len; 159 160 for (i = 0; i < wa->region_nr; i++) { 161 if (wa->fds[i] != fd) 162 continue; 163 164 mr = &wa->vm->regions[i]; 165 166 if (mr->userspace_addr + mr->memory_size < end_addr) 167 mr->memory_size = end_addr - mr->userspace_addr; 168 169 if (mr->userspace_addr > start_addr) { 170 mr->userspace_addr = start_addr; 171 mr->guest_phys_addr = start_addr; 172 } 173 174 if (mr->mmap_offset > offset) 175 mr->mmap_offset = offset; 176 177 PMD_DRV_LOG(DEBUG, "index=%d fd=%d offset=0x%" PRIx64 178 " addr=0x%" PRIx64 " len=%" PRIu64, i, fd, 179 mr->mmap_offset, mr->userspace_addr, 180 mr->memory_size); 181 182 return 0; 183 } 184 185 if (i >= VHOST_MEMORY_MAX_NREGIONS) { 186 PMD_DRV_LOG(ERR, "Too many memory regions"); 187 return -1; 188 } 189 190 mr = &wa->vm->regions[i]; 191 wa->fds[i] = fd; 192 193 mr->guest_phys_addr = start_addr; 194 mr->userspace_addr = start_addr; 195 mr->memory_size = ms->len; 196 mr->mmap_offset = offset; 197 198 PMD_DRV_LOG(DEBUG, "index=%d fd=%d offset=0x%" PRIx64 199 " addr=0x%" PRIx64 " len=%" PRIu64, i, fd, 200 mr->mmap_offset, mr->userspace_addr, 201 mr->memory_size); 202 203 wa->region_nr++; 204 205 return 0; 206 } 207 208 static int 209 prepare_vhost_memory_user(struct vhost_user_msg *msg, int fds[]) 210 { 211 struct walk_arg wa; 212 213 wa.region_nr = 0; 214 wa.vm = &msg->payload.memory; 215 wa.fds = fds; 216 217 /* 218 * The memory lock has already been taken by memory subsystem 219 * or virtio_user_start_device(). 220 */ 221 if (rte_memseg_walk_thread_unsafe(update_memory_region, &wa) < 0) 222 return -1; 223 224 msg->payload.memory.nregions = wa.region_nr; 225 msg->payload.memory.padding = 0; 226 227 return 0; 228 } 229 230 static struct vhost_user_msg m; 231 232 const char * const vhost_msg_strings[] = { 233 [VHOST_USER_SET_OWNER] = "VHOST_SET_OWNER", 234 [VHOST_USER_RESET_OWNER] = "VHOST_RESET_OWNER", 235 [VHOST_USER_SET_FEATURES] = "VHOST_SET_FEATURES", 236 [VHOST_USER_GET_FEATURES] = "VHOST_GET_FEATURES", 237 [VHOST_USER_SET_VRING_CALL] = "VHOST_SET_VRING_CALL", 238 [VHOST_USER_SET_VRING_NUM] = "VHOST_SET_VRING_NUM", 239 [VHOST_USER_SET_VRING_BASE] = "VHOST_SET_VRING_BASE", 240 [VHOST_USER_GET_VRING_BASE] = "VHOST_GET_VRING_BASE", 241 [VHOST_USER_SET_VRING_ADDR] = "VHOST_SET_VRING_ADDR", 242 [VHOST_USER_SET_VRING_KICK] = "VHOST_SET_VRING_KICK", 243 [VHOST_USER_SET_MEM_TABLE] = "VHOST_SET_MEM_TABLE", 244 [VHOST_USER_SET_VRING_ENABLE] = "VHOST_SET_VRING_ENABLE", 245 [VHOST_USER_GET_PROTOCOL_FEATURES] = "VHOST_USER_GET_PROTOCOL_FEATURES", 246 [VHOST_USER_SET_PROTOCOL_FEATURES] = "VHOST_USER_SET_PROTOCOL_FEATURES", 247 [VHOST_USER_SET_STATUS] = "VHOST_SET_STATUS", 248 [VHOST_USER_GET_STATUS] = "VHOST_GET_STATUS", 249 }; 250 251 static int 252 vhost_user_sock(struct virtio_user_dev *dev, 253 enum vhost_user_request req, 254 void *arg) 255 { 256 struct vhost_user_msg msg; 257 struct vhost_vring_file *file = 0; 258 int need_reply = 0; 259 int has_reply_ack = 0; 260 int fds[VHOST_MEMORY_MAX_NREGIONS]; 261 int fd_num = 0; 262 int len; 263 int vhostfd = dev->vhostfd; 264 265 RTE_SET_USED(m); 266 267 PMD_DRV_LOG(INFO, "%s", vhost_msg_strings[req]); 268 269 if (dev->is_server && vhostfd < 0) 270 return -1; 271 272 if (dev->protocol_features & (1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK)) 273 has_reply_ack = 1; 274 275 msg.request = req; 276 msg.flags = VHOST_USER_VERSION; 277 msg.size = 0; 278 279 switch (req) { 280 case VHOST_USER_GET_STATUS: 281 if (!(dev->protocol_features & 282 (1ULL << VHOST_USER_PROTOCOL_F_STATUS))) 283 return 0; 284 /* Fallthrough */ 285 case VHOST_USER_GET_FEATURES: 286 case VHOST_USER_GET_PROTOCOL_FEATURES: 287 need_reply = 1; 288 break; 289 290 case VHOST_USER_SET_STATUS: 291 if (!(dev->protocol_features & 292 (1ULL << VHOST_USER_PROTOCOL_F_STATUS))) 293 return 0; 294 295 if (has_reply_ack) 296 msg.flags |= VHOST_USER_NEED_REPLY_MASK; 297 /* Fallthrough */ 298 case VHOST_USER_SET_FEATURES: 299 case VHOST_USER_SET_PROTOCOL_FEATURES: 300 case VHOST_USER_SET_LOG_BASE: 301 msg.payload.u64 = *((__u64 *)arg); 302 msg.size = sizeof(m.payload.u64); 303 break; 304 305 case VHOST_USER_SET_OWNER: 306 case VHOST_USER_RESET_OWNER: 307 break; 308 309 case VHOST_USER_SET_MEM_TABLE: 310 if (prepare_vhost_memory_user(&msg, fds) < 0) 311 return -1; 312 fd_num = msg.payload.memory.nregions; 313 msg.size = sizeof(m.payload.memory.nregions); 314 msg.size += sizeof(m.payload.memory.padding); 315 msg.size += fd_num * sizeof(struct vhost_memory_region); 316 317 if (has_reply_ack) 318 msg.flags |= VHOST_USER_NEED_REPLY_MASK; 319 break; 320 321 case VHOST_USER_SET_LOG_FD: 322 fds[fd_num++] = *((int *)arg); 323 break; 324 325 case VHOST_USER_SET_VRING_NUM: 326 case VHOST_USER_SET_VRING_BASE: 327 case VHOST_USER_SET_VRING_ENABLE: 328 memcpy(&msg.payload.state, arg, sizeof(msg.payload.state)); 329 msg.size = sizeof(m.payload.state); 330 break; 331 332 case VHOST_USER_GET_VRING_BASE: 333 memcpy(&msg.payload.state, arg, sizeof(msg.payload.state)); 334 msg.size = sizeof(m.payload.state); 335 need_reply = 1; 336 break; 337 338 case VHOST_USER_SET_VRING_ADDR: 339 memcpy(&msg.payload.addr, arg, sizeof(msg.payload.addr)); 340 msg.size = sizeof(m.payload.addr); 341 break; 342 343 case VHOST_USER_SET_VRING_KICK: 344 case VHOST_USER_SET_VRING_CALL: 345 case VHOST_USER_SET_VRING_ERR: 346 file = arg; 347 msg.payload.u64 = file->index & VHOST_USER_VRING_IDX_MASK; 348 msg.size = sizeof(m.payload.u64); 349 if (file->fd > 0) 350 fds[fd_num++] = file->fd; 351 else 352 msg.payload.u64 |= VHOST_USER_VRING_NOFD_MASK; 353 break; 354 355 default: 356 PMD_DRV_LOG(ERR, "trying to send unhandled msg type"); 357 return -1; 358 } 359 360 len = VHOST_USER_HDR_SIZE + msg.size; 361 if (vhost_user_write(vhostfd, &msg, len, fds, fd_num) < 0) { 362 PMD_DRV_LOG(ERR, "%s failed: %s", 363 vhost_msg_strings[req], strerror(errno)); 364 return -1; 365 } 366 367 if (need_reply || msg.flags & VHOST_USER_NEED_REPLY_MASK) { 368 if (vhost_user_read(vhostfd, &msg) < 0) { 369 PMD_DRV_LOG(ERR, "Received msg failed: %s", 370 strerror(errno)); 371 return -1; 372 } 373 374 if (req != msg.request) { 375 PMD_DRV_LOG(ERR, "Received unexpected msg type"); 376 return -1; 377 } 378 379 switch (req) { 380 case VHOST_USER_GET_FEATURES: 381 case VHOST_USER_GET_STATUS: 382 case VHOST_USER_GET_PROTOCOL_FEATURES: 383 if (msg.size != sizeof(m.payload.u64)) { 384 PMD_DRV_LOG(ERR, "Received bad msg size"); 385 return -1; 386 } 387 *((__u64 *)arg) = msg.payload.u64; 388 break; 389 case VHOST_USER_GET_VRING_BASE: 390 if (msg.size != sizeof(m.payload.state)) { 391 PMD_DRV_LOG(ERR, "Received bad msg size"); 392 return -1; 393 } 394 memcpy(arg, &msg.payload.state, 395 sizeof(struct vhost_vring_state)); 396 break; 397 default: 398 /* Reply-ack handling */ 399 if (msg.size != sizeof(m.payload.u64)) { 400 PMD_DRV_LOG(ERR, "Received bad msg size"); 401 return -1; 402 } 403 404 if (msg.payload.u64 != 0) { 405 PMD_DRV_LOG(ERR, "Slave replied NACK"); 406 return -1; 407 } 408 409 break; 410 } 411 } 412 413 return 0; 414 } 415 416 #define MAX_VIRTIO_USER_BACKLOG 1 417 static int 418 virtio_user_start_server(struct virtio_user_dev *dev, struct sockaddr_un *un) 419 { 420 int ret; 421 int flag; 422 int fd = dev->listenfd; 423 424 ret = bind(fd, (struct sockaddr *)un, sizeof(*un)); 425 if (ret < 0) { 426 PMD_DRV_LOG(ERR, "failed to bind to %s: %s; remove it and try again\n", 427 dev->path, strerror(errno)); 428 return -1; 429 } 430 ret = listen(fd, MAX_VIRTIO_USER_BACKLOG); 431 if (ret < 0) 432 return -1; 433 434 flag = fcntl(fd, F_GETFL); 435 if (fcntl(fd, F_SETFL, flag | O_NONBLOCK) < 0) { 436 PMD_DRV_LOG(ERR, "fcntl failed, %s", strerror(errno)); 437 return -1; 438 } 439 440 return 0; 441 } 442 443 /** 444 * Set up environment to talk with a vhost user backend. 445 * 446 * @return 447 * - (-1) if fail; 448 * - (0) if succeed. 449 */ 450 static int 451 vhost_user_setup(struct virtio_user_dev *dev) 452 { 453 int fd; 454 int flag; 455 struct sockaddr_un un; 456 457 fd = socket(AF_UNIX, SOCK_STREAM, 0); 458 if (fd < 0) { 459 PMD_DRV_LOG(ERR, "socket() error, %s", strerror(errno)); 460 return -1; 461 } 462 463 flag = fcntl(fd, F_GETFD); 464 if (fcntl(fd, F_SETFD, flag | FD_CLOEXEC) < 0) 465 PMD_DRV_LOG(WARNING, "fcntl failed, %s", strerror(errno)); 466 467 memset(&un, 0, sizeof(un)); 468 un.sun_family = AF_UNIX; 469 strlcpy(un.sun_path, dev->path, sizeof(un.sun_path)); 470 471 if (dev->is_server) { 472 dev->listenfd = fd; 473 if (virtio_user_start_server(dev, &un) < 0) { 474 PMD_DRV_LOG(ERR, "virtio-user startup fails in server mode"); 475 close(fd); 476 return -1; 477 } 478 dev->vhostfd = -1; 479 } else { 480 if (connect(fd, (struct sockaddr *)&un, sizeof(un)) < 0) { 481 PMD_DRV_LOG(ERR, "connect error, %s", strerror(errno)); 482 close(fd); 483 return -1; 484 } 485 dev->vhostfd = fd; 486 } 487 488 return 0; 489 } 490 491 static int 492 vhost_user_enable_queue_pair(struct virtio_user_dev *dev, 493 uint16_t pair_idx, 494 int enable) 495 { 496 int i; 497 498 if (dev->qp_enabled[pair_idx] == enable) 499 return 0; 500 501 for (i = 0; i < 2; ++i) { 502 struct vhost_vring_state state = { 503 .index = pair_idx * 2 + i, 504 .num = enable, 505 }; 506 507 if (vhost_user_sock(dev, VHOST_USER_SET_VRING_ENABLE, &state)) 508 return -1; 509 } 510 511 dev->qp_enabled[pair_idx] = enable; 512 return 0; 513 } 514 515 struct virtio_user_backend_ops virtio_ops_user = { 516 .setup = vhost_user_setup, 517 .send_request = vhost_user_sock, 518 .enable_qp = vhost_user_enable_queue_pair 519 }; 520