1 /* SPDX-License-Identifier: BSD-3-Clause 2 * Copyright(c) 2010-2016 Intel Corporation 3 */ 4 5 #include <sys/socket.h> 6 #include <sys/types.h> 7 #include <sys/stat.h> 8 #include <unistd.h> 9 #include <fcntl.h> 10 #include <sys/un.h> 11 #include <string.h> 12 #include <errno.h> 13 14 #include <rte_string_fns.h> 15 #include <rte_fbarray.h> 16 17 #include "vhost.h" 18 #include "virtio_user_dev.h" 19 20 /* The version of the protocol we support */ 21 #define VHOST_USER_VERSION 0x1 22 23 #define VHOST_MEMORY_MAX_NREGIONS 8 24 struct vhost_memory { 25 uint32_t nregions; 26 uint32_t padding; 27 struct vhost_memory_region regions[VHOST_MEMORY_MAX_NREGIONS]; 28 }; 29 30 struct vhost_user_msg { 31 enum vhost_user_request request; 32 33 #define VHOST_USER_VERSION_MASK 0x3 34 #define VHOST_USER_REPLY_MASK (0x1 << 2) 35 #define VHOST_USER_NEED_REPLY_MASK (0x1 << 3) 36 uint32_t flags; 37 uint32_t size; /* the following payload size */ 38 union { 39 #define VHOST_USER_VRING_IDX_MASK 0xff 40 #define VHOST_USER_VRING_NOFD_MASK (0x1 << 8) 41 uint64_t u64; 42 struct vhost_vring_state state; 43 struct vhost_vring_addr addr; 44 struct vhost_memory memory; 45 } payload; 46 int fds[VHOST_MEMORY_MAX_NREGIONS]; 47 } __rte_packed; 48 49 #define VHOST_USER_HDR_SIZE offsetof(struct vhost_user_msg, payload.u64) 50 #define VHOST_USER_PAYLOAD_SIZE \ 51 (sizeof(struct vhost_user_msg) - VHOST_USER_HDR_SIZE) 52 53 static int 54 vhost_user_write(int fd, void *buf, int len, int *fds, int fd_num) 55 { 56 int r; 57 struct msghdr msgh; 58 struct iovec iov; 59 size_t fd_size = fd_num * sizeof(int); 60 char control[CMSG_SPACE(fd_size)]; 61 struct cmsghdr *cmsg; 62 63 memset(&msgh, 0, sizeof(msgh)); 64 memset(control, 0, sizeof(control)); 65 66 iov.iov_base = (uint8_t *)buf; 67 iov.iov_len = len; 68 69 msgh.msg_iov = &iov; 70 msgh.msg_iovlen = 1; 71 msgh.msg_control = control; 72 msgh.msg_controllen = sizeof(control); 73 74 cmsg = CMSG_FIRSTHDR(&msgh); 75 cmsg->cmsg_len = CMSG_LEN(fd_size); 76 cmsg->cmsg_level = SOL_SOCKET; 77 cmsg->cmsg_type = SCM_RIGHTS; 78 memcpy(CMSG_DATA(cmsg), fds, fd_size); 79 80 do { 81 r = sendmsg(fd, &msgh, 0); 82 } while (r < 0 && errno == EINTR); 83 84 return r; 85 } 86 87 static int 88 vhost_user_read(int fd, struct vhost_user_msg *msg) 89 { 90 uint32_t valid_flags = VHOST_USER_REPLY_MASK | VHOST_USER_VERSION; 91 int ret, sz_hdr = VHOST_USER_HDR_SIZE, sz_payload; 92 93 ret = recv(fd, (void *)msg, sz_hdr, 0); 94 if (ret < sz_hdr) { 95 PMD_DRV_LOG(ERR, "Failed to recv msg hdr: %d instead of %d.", 96 ret, sz_hdr); 97 goto fail; 98 } 99 100 /* validate msg flags */ 101 if (msg->flags != (valid_flags)) { 102 PMD_DRV_LOG(ERR, "Failed to recv msg: flags %x instead of %x.", 103 msg->flags, valid_flags); 104 goto fail; 105 } 106 107 sz_payload = msg->size; 108 109 if ((size_t)sz_payload > sizeof(msg->payload)) 110 goto fail; 111 112 if (sz_payload) { 113 ret = recv(fd, (void *)((char *)msg + sz_hdr), sz_payload, 0); 114 if (ret < sz_payload) { 115 PMD_DRV_LOG(ERR, 116 "Failed to recv msg payload: %d instead of %d.", 117 ret, msg->size); 118 goto fail; 119 } 120 } 121 122 return 0; 123 124 fail: 125 return -1; 126 } 127 128 struct walk_arg { 129 struct vhost_memory *vm; 130 int *fds; 131 int region_nr; 132 }; 133 134 static int 135 update_memory_region(const struct rte_memseg_list *msl __rte_unused, 136 const struct rte_memseg *ms, void *arg) 137 { 138 struct walk_arg *wa = arg; 139 struct vhost_memory_region *mr; 140 uint64_t start_addr, end_addr; 141 size_t offset; 142 int i, fd; 143 144 fd = rte_memseg_get_fd_thread_unsafe(ms); 145 if (fd < 0) { 146 PMD_DRV_LOG(ERR, "Failed to get fd, ms=%p rte_errno=%d", 147 ms, rte_errno); 148 return -1; 149 } 150 151 if (rte_memseg_get_fd_offset_thread_unsafe(ms, &offset) < 0) { 152 PMD_DRV_LOG(ERR, "Failed to get offset, ms=%p rte_errno=%d", 153 ms, rte_errno); 154 return -1; 155 } 156 157 start_addr = (uint64_t)(uintptr_t)ms->addr; 158 end_addr = start_addr + ms->len; 159 160 for (i = 0; i < wa->region_nr; i++) { 161 if (wa->fds[i] != fd) 162 continue; 163 164 mr = &wa->vm->regions[i]; 165 166 if (mr->userspace_addr + mr->memory_size < end_addr) 167 mr->memory_size = end_addr - mr->userspace_addr; 168 169 if (mr->userspace_addr > start_addr) { 170 mr->userspace_addr = start_addr; 171 mr->guest_phys_addr = start_addr; 172 } 173 174 if (mr->mmap_offset > offset) 175 mr->mmap_offset = offset; 176 177 PMD_DRV_LOG(DEBUG, "index=%d fd=%d offset=0x%" PRIx64 178 " addr=0x%" PRIx64 " len=%" PRIu64, i, fd, 179 mr->mmap_offset, mr->userspace_addr, 180 mr->memory_size); 181 182 return 0; 183 } 184 185 if (i >= VHOST_MEMORY_MAX_NREGIONS) { 186 PMD_DRV_LOG(ERR, "Too many memory regions"); 187 return -1; 188 } 189 190 mr = &wa->vm->regions[i]; 191 wa->fds[i] = fd; 192 193 mr->guest_phys_addr = start_addr; 194 mr->userspace_addr = start_addr; 195 mr->memory_size = ms->len; 196 mr->mmap_offset = offset; 197 198 PMD_DRV_LOG(DEBUG, "index=%d fd=%d offset=0x%" PRIx64 199 " addr=0x%" PRIx64 " len=%" PRIu64, i, fd, 200 mr->mmap_offset, mr->userspace_addr, 201 mr->memory_size); 202 203 wa->region_nr++; 204 205 return 0; 206 } 207 208 static int 209 prepare_vhost_memory_user(struct vhost_user_msg *msg, int fds[]) 210 { 211 struct walk_arg wa; 212 213 wa.region_nr = 0; 214 wa.vm = &msg->payload.memory; 215 wa.fds = fds; 216 217 /* 218 * The memory lock has already been taken by memory subsystem 219 * or virtio_user_start_device(). 220 */ 221 if (rte_memseg_walk_thread_unsafe(update_memory_region, &wa) < 0) 222 return -1; 223 224 msg->payload.memory.nregions = wa.region_nr; 225 msg->payload.memory.padding = 0; 226 227 return 0; 228 } 229 230 static struct vhost_user_msg m; 231 232 const char * const vhost_msg_strings[] = { 233 [VHOST_USER_SET_OWNER] = "VHOST_SET_OWNER", 234 [VHOST_USER_RESET_OWNER] = "VHOST_RESET_OWNER", 235 [VHOST_USER_SET_FEATURES] = "VHOST_SET_FEATURES", 236 [VHOST_USER_GET_FEATURES] = "VHOST_GET_FEATURES", 237 [VHOST_USER_SET_VRING_CALL] = "VHOST_SET_VRING_CALL", 238 [VHOST_USER_SET_VRING_NUM] = "VHOST_SET_VRING_NUM", 239 [VHOST_USER_SET_VRING_BASE] = "VHOST_SET_VRING_BASE", 240 [VHOST_USER_GET_VRING_BASE] = "VHOST_GET_VRING_BASE", 241 [VHOST_USER_SET_VRING_ADDR] = "VHOST_SET_VRING_ADDR", 242 [VHOST_USER_SET_VRING_KICK] = "VHOST_SET_VRING_KICK", 243 [VHOST_USER_SET_MEM_TABLE] = "VHOST_SET_MEM_TABLE", 244 [VHOST_USER_SET_VRING_ENABLE] = "VHOST_SET_VRING_ENABLE", 245 [VHOST_USER_GET_PROTOCOL_FEATURES] = "VHOST_USER_GET_PROTOCOL_FEATURES", 246 [VHOST_USER_SET_PROTOCOL_FEATURES] = "VHOST_USER_SET_PROTOCOL_FEATURES", 247 }; 248 249 static int 250 vhost_user_sock(struct virtio_user_dev *dev, 251 enum vhost_user_request req, 252 void *arg) 253 { 254 struct vhost_user_msg msg; 255 struct vhost_vring_file *file = 0; 256 int need_reply = 0; 257 int has_reply_ack; 258 int fds[VHOST_MEMORY_MAX_NREGIONS]; 259 int fd_num = 0; 260 int len; 261 int vhostfd = dev->vhostfd; 262 263 RTE_SET_USED(m); 264 265 PMD_DRV_LOG(INFO, "%s", vhost_msg_strings[req]); 266 267 if (dev->is_server && vhostfd < 0) 268 return -1; 269 270 if (dev->protocol_features & (1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK)) 271 has_reply_ack = 1; 272 273 msg.request = req; 274 msg.flags = VHOST_USER_VERSION; 275 msg.size = 0; 276 277 switch (req) { 278 case VHOST_USER_GET_FEATURES: 279 case VHOST_USER_GET_PROTOCOL_FEATURES: 280 need_reply = 1; 281 break; 282 283 case VHOST_USER_SET_FEATURES: 284 case VHOST_USER_SET_PROTOCOL_FEATURES: 285 case VHOST_USER_SET_LOG_BASE: 286 msg.payload.u64 = *((__u64 *)arg); 287 msg.size = sizeof(m.payload.u64); 288 break; 289 290 case VHOST_USER_SET_OWNER: 291 case VHOST_USER_RESET_OWNER: 292 break; 293 294 case VHOST_USER_SET_MEM_TABLE: 295 if (prepare_vhost_memory_user(&msg, fds) < 0) 296 return -1; 297 fd_num = msg.payload.memory.nregions; 298 msg.size = sizeof(m.payload.memory.nregions); 299 msg.size += sizeof(m.payload.memory.padding); 300 msg.size += fd_num * sizeof(struct vhost_memory_region); 301 302 if (has_reply_ack) 303 msg.flags |= VHOST_USER_NEED_REPLY_MASK; 304 break; 305 306 case VHOST_USER_SET_LOG_FD: 307 fds[fd_num++] = *((int *)arg); 308 break; 309 310 case VHOST_USER_SET_VRING_NUM: 311 case VHOST_USER_SET_VRING_BASE: 312 case VHOST_USER_SET_VRING_ENABLE: 313 memcpy(&msg.payload.state, arg, sizeof(msg.payload.state)); 314 msg.size = sizeof(m.payload.state); 315 break; 316 317 case VHOST_USER_GET_VRING_BASE: 318 memcpy(&msg.payload.state, arg, sizeof(msg.payload.state)); 319 msg.size = sizeof(m.payload.state); 320 need_reply = 1; 321 break; 322 323 case VHOST_USER_SET_VRING_ADDR: 324 memcpy(&msg.payload.addr, arg, sizeof(msg.payload.addr)); 325 msg.size = sizeof(m.payload.addr); 326 break; 327 328 case VHOST_USER_SET_VRING_KICK: 329 case VHOST_USER_SET_VRING_CALL: 330 case VHOST_USER_SET_VRING_ERR: 331 file = arg; 332 msg.payload.u64 = file->index & VHOST_USER_VRING_IDX_MASK; 333 msg.size = sizeof(m.payload.u64); 334 if (file->fd > 0) 335 fds[fd_num++] = file->fd; 336 else 337 msg.payload.u64 |= VHOST_USER_VRING_NOFD_MASK; 338 break; 339 340 default: 341 PMD_DRV_LOG(ERR, "trying to send unhandled msg type"); 342 return -1; 343 } 344 345 len = VHOST_USER_HDR_SIZE + msg.size; 346 if (vhost_user_write(vhostfd, &msg, len, fds, fd_num) < 0) { 347 PMD_DRV_LOG(ERR, "%s failed: %s", 348 vhost_msg_strings[req], strerror(errno)); 349 return -1; 350 } 351 352 if (need_reply || msg.flags & VHOST_USER_NEED_REPLY_MASK) { 353 if (vhost_user_read(vhostfd, &msg) < 0) { 354 PMD_DRV_LOG(ERR, "Received msg failed: %s", 355 strerror(errno)); 356 return -1; 357 } 358 359 if (req != msg.request) { 360 PMD_DRV_LOG(ERR, "Received unexpected msg type"); 361 return -1; 362 } 363 364 switch (req) { 365 case VHOST_USER_GET_FEATURES: 366 case VHOST_USER_GET_PROTOCOL_FEATURES: 367 if (msg.size != sizeof(m.payload.u64)) { 368 PMD_DRV_LOG(ERR, "Received bad msg size"); 369 return -1; 370 } 371 *((__u64 *)arg) = msg.payload.u64; 372 break; 373 case VHOST_USER_GET_VRING_BASE: 374 if (msg.size != sizeof(m.payload.state)) { 375 PMD_DRV_LOG(ERR, "Received bad msg size"); 376 return -1; 377 } 378 memcpy(arg, &msg.payload.state, 379 sizeof(struct vhost_vring_state)); 380 break; 381 default: 382 /* Reply-ack handling */ 383 if (msg.size != sizeof(m.payload.u64)) { 384 PMD_DRV_LOG(ERR, "Received bad msg size"); 385 return -1; 386 } 387 388 if (msg.payload.u64 != 0) { 389 PMD_DRV_LOG(ERR, "Slave replied NACK"); 390 return -1; 391 } 392 393 break; 394 } 395 } 396 397 return 0; 398 } 399 400 #define MAX_VIRTIO_USER_BACKLOG 1 401 static int 402 virtio_user_start_server(struct virtio_user_dev *dev, struct sockaddr_un *un) 403 { 404 int ret; 405 int flag; 406 int fd = dev->listenfd; 407 408 ret = bind(fd, (struct sockaddr *)un, sizeof(*un)); 409 if (ret < 0) { 410 PMD_DRV_LOG(ERR, "failed to bind to %s: %s; remove it and try again\n", 411 dev->path, strerror(errno)); 412 return -1; 413 } 414 ret = listen(fd, MAX_VIRTIO_USER_BACKLOG); 415 if (ret < 0) 416 return -1; 417 418 flag = fcntl(fd, F_GETFL); 419 if (fcntl(fd, F_SETFL, flag | O_NONBLOCK) < 0) { 420 PMD_DRV_LOG(ERR, "fcntl failed, %s", strerror(errno)); 421 return -1; 422 } 423 424 return 0; 425 } 426 427 /** 428 * Set up environment to talk with a vhost user backend. 429 * 430 * @return 431 * - (-1) if fail; 432 * - (0) if succeed. 433 */ 434 static int 435 vhost_user_setup(struct virtio_user_dev *dev) 436 { 437 int fd; 438 int flag; 439 struct sockaddr_un un; 440 441 fd = socket(AF_UNIX, SOCK_STREAM, 0); 442 if (fd < 0) { 443 PMD_DRV_LOG(ERR, "socket() error, %s", strerror(errno)); 444 return -1; 445 } 446 447 flag = fcntl(fd, F_GETFD); 448 if (fcntl(fd, F_SETFD, flag | FD_CLOEXEC) < 0) 449 PMD_DRV_LOG(WARNING, "fcntl failed, %s", strerror(errno)); 450 451 memset(&un, 0, sizeof(un)); 452 un.sun_family = AF_UNIX; 453 strlcpy(un.sun_path, dev->path, sizeof(un.sun_path)); 454 455 if (dev->is_server) { 456 dev->listenfd = fd; 457 if (virtio_user_start_server(dev, &un) < 0) { 458 PMD_DRV_LOG(ERR, "virtio-user startup fails in server mode"); 459 close(fd); 460 return -1; 461 } 462 dev->vhostfd = -1; 463 } else { 464 if (connect(fd, (struct sockaddr *)&un, sizeof(un)) < 0) { 465 PMD_DRV_LOG(ERR, "connect error, %s", strerror(errno)); 466 close(fd); 467 return -1; 468 } 469 dev->vhostfd = fd; 470 } 471 472 return 0; 473 } 474 475 static int 476 vhost_user_enable_queue_pair(struct virtio_user_dev *dev, 477 uint16_t pair_idx, 478 int enable) 479 { 480 int i; 481 482 if (dev->qp_enabled[pair_idx] == enable) 483 return 0; 484 485 for (i = 0; i < 2; ++i) { 486 struct vhost_vring_state state = { 487 .index = pair_idx * 2 + i, 488 .num = enable, 489 }; 490 491 if (vhost_user_sock(dev, VHOST_USER_SET_VRING_ENABLE, &state)) 492 return -1; 493 } 494 495 dev->qp_enabled[pair_idx] = enable; 496 return 0; 497 } 498 499 struct virtio_user_backend_ops virtio_ops_user = { 500 .setup = vhost_user_setup, 501 .send_request = vhost_user_sock, 502 .enable_qp = vhost_user_enable_queue_pair 503 }; 504