| .. | .. |
|---|
| 1 | +// SPDX-License-Identifier: GPL-2.0-only |
|---|
| 1 | 2 | /* |
|---|
| 2 | 3 | * common code for virtio vsock |
|---|
| 3 | 4 | * |
|---|
| 4 | 5 | * Copyright (C) 2013-2015 Red Hat, Inc. |
|---|
| 5 | 6 | * Author: Asias He <asias@redhat.com> |
|---|
| 6 | 7 | * Stefan Hajnoczi <stefanha@redhat.com> |
|---|
| 7 | | - * |
|---|
| 8 | | - * This work is licensed under the terms of the GNU GPL, version 2. |
|---|
| 9 | 8 | */ |
|---|
| 10 | 9 | #include <linux/spinlock.h> |
|---|
| 11 | 10 | #include <linux/module.h> |
|---|
| 12 | 11 | #include <linux/sched/signal.h> |
|---|
| 13 | 12 | #include <linux/ctype.h> |
|---|
| 14 | 13 | #include <linux/list.h> |
|---|
| 15 | | -#include <linux/virtio.h> |
|---|
| 16 | | -#include <linux/virtio_ids.h> |
|---|
| 17 | | -#include <linux/virtio_config.h> |
|---|
| 18 | 14 | #include <linux/virtio_vsock.h> |
|---|
| 19 | 15 | #include <uapi/linux/vsockmon.h> |
|---|
| 20 | 16 | |
|---|
| .. | .. |
|---|
| 27 | 23 | /* How long to wait for graceful shutdown of a connection */ |
|---|
| 28 | 24 | #define VSOCK_CLOSE_TIMEOUT (8 * HZ) |
|---|
| 29 | 25 | |
|---|
| 26 | +/* Threshold for detecting small packets to copy */ |
|---|
| 27 | +#define GOOD_COPY_LEN 128 |
|---|
| 28 | + |
|---|
| 30 | 29 | uint virtio_transport_max_vsock_pkt_buf_size = 64 * 1024; |
|---|
| 31 | 30 | module_param(virtio_transport_max_vsock_pkt_buf_size, uint, 0444); |
|---|
| 32 | 31 | EXPORT_SYMBOL_GPL(virtio_transport_max_vsock_pkt_buf_size); |
|---|
| 33 | 32 | |
|---|
| 34 | | -static const struct virtio_transport *virtio_transport_get_ops(void) |
|---|
| 33 | +static const struct virtio_transport * |
|---|
| 34 | +virtio_transport_get_ops(struct vsock_sock *vsk) |
|---|
| 35 | 35 | { |
|---|
| 36 | | - const struct vsock_transport *t = vsock_core_get_transport(); |
|---|
| 36 | + const struct vsock_transport *t = vsock_core_get_transport(vsk); |
|---|
| 37 | + |
|---|
| 38 | + if (WARN_ON(!t)) |
|---|
| 39 | + return NULL; |
|---|
| 37 | 40 | |
|---|
| 38 | 41 | return container_of(t, struct virtio_transport, transport); |
|---|
| 39 | 42 | } |
|---|
| .. | .. |
|---|
| 69 | 72 | pkt->buf = kmalloc(len, GFP_KERNEL); |
|---|
| 70 | 73 | if (!pkt->buf) |
|---|
| 71 | 74 | goto out_pkt; |
|---|
| 75 | + |
|---|
| 76 | + pkt->buf_len = len; |
|---|
| 77 | + |
|---|
| 72 | 78 | err = memcpy_from_msg(pkt->buf, info->msg, len); |
|---|
| 73 | 79 | if (err) |
|---|
| 74 | 80 | goto out; |
|---|
| .. | .. |
|---|
| 155 | 161 | |
|---|
| 156 | 162 | void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt) |
|---|
| 157 | 163 | { |
|---|
| 164 | + if (pkt->tap_delivered) |
|---|
| 165 | + return; |
|---|
| 166 | + |
|---|
| 158 | 167 | vsock_deliver_tap(virtio_transport_build_skb, pkt); |
|---|
| 168 | + pkt->tap_delivered = true; |
|---|
| 159 | 169 | } |
|---|
| 160 | 170 | EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt); |
|---|
| 161 | 171 | |
|---|
| 172 | +/* This function can only be used on connecting/connected sockets, |
|---|
| 173 | + * since a socket assigned to a transport is required. |
|---|
| 174 | + * |
|---|
| 175 | + * Do not use on listener sockets! |
|---|
| 176 | + */ |
|---|
| 162 | 177 | static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, |
|---|
| 163 | 178 | struct virtio_vsock_pkt_info *info) |
|---|
| 164 | 179 | { |
|---|
| 165 | 180 | u32 src_cid, src_port, dst_cid, dst_port; |
|---|
| 181 | + const struct virtio_transport *t_ops; |
|---|
| 166 | 182 | struct virtio_vsock_sock *vvs; |
|---|
| 167 | 183 | struct virtio_vsock_pkt *pkt; |
|---|
| 168 | 184 | u32 pkt_len = info->pkt_len; |
|---|
| 169 | 185 | |
|---|
| 170 | | - src_cid = vm_sockets_get_local_cid(); |
|---|
| 186 | + t_ops = virtio_transport_get_ops(vsk); |
|---|
| 187 | + if (unlikely(!t_ops)) |
|---|
| 188 | + return -EFAULT; |
|---|
| 189 | + |
|---|
| 190 | + src_cid = t_ops->transport.get_local_cid(); |
|---|
| 171 | 191 | src_port = vsk->local_addr.svm_port; |
|---|
| 172 | 192 | if (!info->remote_cid) { |
|---|
| 173 | 193 | dst_cid = vsk->remote_addr.svm_cid; |
|---|
| .. | .. |
|---|
| 180 | 200 | vvs = vsk->trans; |
|---|
| 181 | 201 | |
|---|
| 182 | 202 | /* we can send less than pkt_len bytes */ |
|---|
| 183 | | - if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE) |
|---|
| 184 | | - pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE; |
|---|
| 203 | + if (pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) |
|---|
| 204 | + pkt_len = VIRTIO_VSOCK_MAX_PKT_BUF_SIZE; |
|---|
| 185 | 205 | |
|---|
| 186 | 206 | /* virtio_transport_get_credit might return less than pkt_len credit */ |
|---|
| 187 | 207 | pkt_len = virtio_transport_get_credit(vvs, pkt_len); |
|---|
| .. | .. |
|---|
| 200 | 220 | |
|---|
| 201 | 221 | virtio_transport_inc_tx_pkt(vvs, pkt); |
|---|
| 202 | 222 | |
|---|
| 203 | | - return virtio_transport_get_ops()->send_pkt(pkt); |
|---|
| 223 | + return t_ops->send_pkt(pkt); |
|---|
| 204 | 224 | } |
|---|
| 205 | 225 | |
|---|
| 206 | | -static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, |
|---|
| 226 | +static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, |
|---|
| 207 | 227 | struct virtio_vsock_pkt *pkt) |
|---|
| 208 | 228 | { |
|---|
| 229 | + if (vvs->rx_bytes + pkt->len > vvs->buf_alloc) |
|---|
| 230 | + return false; |
|---|
| 231 | + |
|---|
| 209 | 232 | vvs->rx_bytes += pkt->len; |
|---|
| 233 | + return true; |
|---|
| 210 | 234 | } |
|---|
| 211 | 235 | |
|---|
| 212 | 236 | static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs, |
|---|
| .. | .. |
|---|
| 218 | 242 | |
|---|
| 219 | 243 | void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt) |
|---|
| 220 | 244 | { |
|---|
| 221 | | - spin_lock_bh(&vvs->tx_lock); |
|---|
| 245 | + spin_lock_bh(&vvs->rx_lock); |
|---|
| 246 | + vvs->last_fwd_cnt = vvs->fwd_cnt; |
|---|
| 222 | 247 | pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt); |
|---|
| 223 | 248 | pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc); |
|---|
| 224 | | - spin_unlock_bh(&vvs->tx_lock); |
|---|
| 249 | + spin_unlock_bh(&vvs->rx_lock); |
|---|
| 225 | 250 | } |
|---|
| 226 | 251 | EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt); |
|---|
| 227 | 252 | |
|---|
| .. | .. |
|---|
| 262 | 287 | } |
|---|
| 263 | 288 | |
|---|
| 264 | 289 | static ssize_t |
|---|
| 290 | +virtio_transport_stream_do_peek(struct vsock_sock *vsk, |
|---|
| 291 | + struct msghdr *msg, |
|---|
| 292 | + size_t len) |
|---|
| 293 | +{ |
|---|
| 294 | + struct virtio_vsock_sock *vvs = vsk->trans; |
|---|
| 295 | + struct virtio_vsock_pkt *pkt; |
|---|
| 296 | + size_t bytes, total = 0, off; |
|---|
| 297 | + int err = -EFAULT; |
|---|
| 298 | + |
|---|
| 299 | + spin_lock_bh(&vvs->rx_lock); |
|---|
| 300 | + |
|---|
| 301 | + list_for_each_entry(pkt, &vvs->rx_queue, list) { |
|---|
| 302 | + off = pkt->off; |
|---|
| 303 | + |
|---|
| 304 | + if (total == len) |
|---|
| 305 | + break; |
|---|
| 306 | + |
|---|
| 307 | + while (total < len && off < pkt->len) { |
|---|
| 308 | + bytes = len - total; |
|---|
| 309 | + if (bytes > pkt->len - off) |
|---|
| 310 | + bytes = pkt->len - off; |
|---|
| 311 | + |
|---|
| 312 | + /* sk_lock is held by caller so no one else can dequeue. |
|---|
| 313 | + * Unlock rx_lock since memcpy_to_msg() may sleep. |
|---|
| 314 | + */ |
|---|
| 315 | + spin_unlock_bh(&vvs->rx_lock); |
|---|
| 316 | + |
|---|
| 317 | + err = memcpy_to_msg(msg, pkt->buf + off, bytes); |
|---|
| 318 | + if (err) |
|---|
| 319 | + goto out; |
|---|
| 320 | + |
|---|
| 321 | + spin_lock_bh(&vvs->rx_lock); |
|---|
| 322 | + |
|---|
| 323 | + total += bytes; |
|---|
| 324 | + off += bytes; |
|---|
| 325 | + } |
|---|
| 326 | + } |
|---|
| 327 | + |
|---|
| 328 | + spin_unlock_bh(&vvs->rx_lock); |
|---|
| 329 | + |
|---|
| 330 | + return total; |
|---|
| 331 | + |
|---|
| 332 | +out: |
|---|
| 333 | + if (total) |
|---|
| 334 | + err = total; |
|---|
| 335 | + return err; |
|---|
| 336 | +} |
|---|
| 337 | + |
|---|
| 338 | +static ssize_t |
|---|
| 265 | 339 | virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, |
|---|
| 266 | 340 | struct msghdr *msg, |
|---|
| 267 | 341 | size_t len) |
|---|
| .. | .. |
|---|
| 269 | 343 | struct virtio_vsock_sock *vvs = vsk->trans; |
|---|
| 270 | 344 | struct virtio_vsock_pkt *pkt; |
|---|
| 271 | 345 | size_t bytes, total = 0; |
|---|
| 346 | + u32 free_space; |
|---|
| 272 | 347 | int err = -EFAULT; |
|---|
| 273 | 348 | |
|---|
| 274 | 349 | spin_lock_bh(&vvs->rx_lock); |
|---|
| .. | .. |
|---|
| 299 | 374 | virtio_transport_free_pkt(pkt); |
|---|
| 300 | 375 | } |
|---|
| 301 | 376 | } |
|---|
| 377 | + |
|---|
| 378 | + free_space = vvs->buf_alloc - (vvs->fwd_cnt - vvs->last_fwd_cnt); |
|---|
| 379 | + |
|---|
| 302 | 380 | spin_unlock_bh(&vvs->rx_lock); |
|---|
| 303 | 381 | |
|---|
| 304 | | - /* Send a credit pkt to peer */ |
|---|
| 305 | | - virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM, |
|---|
| 306 | | - NULL); |
|---|
| 382 | + /* To reduce the number of credit update messages, |
|---|
| 383 | + * don't update credits as long as lots of space is available. |
|---|
| 384 | + * Note: the limit chosen here is arbitrary. Setting the limit |
|---|
| 385 | + * too high causes extra messages. Too low causes transmitter |
|---|
| 386 | + * stalls. As stalls are in theory more expensive than extra |
|---|
| 387 | + * messages, we set the limit to a high value. TODO: experiment |
|---|
| 388 | + * with different values. |
|---|
| 389 | + */ |
|---|
| 390 | + if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) { |
|---|
| 391 | + virtio_transport_send_credit_update(vsk, |
|---|
| 392 | + VIRTIO_VSOCK_TYPE_STREAM, |
|---|
| 393 | + NULL); |
|---|
| 394 | + } |
|---|
| 307 | 395 | |
|---|
| 308 | 396 | return total; |
|---|
| 309 | 397 | |
|---|
| .. | .. |
|---|
| 319 | 407 | size_t len, int flags) |
|---|
| 320 | 408 | { |
|---|
| 321 | 409 | if (flags & MSG_PEEK) |
|---|
| 322 | | - return -EOPNOTSUPP; |
|---|
| 323 | | - |
|---|
| 324 | | - return virtio_transport_stream_do_dequeue(vsk, msg, len); |
|---|
| 410 | + return virtio_transport_stream_do_peek(vsk, msg, len); |
|---|
| 411 | + else |
|---|
| 412 | + return virtio_transport_stream_do_dequeue(vsk, msg, len); |
|---|
| 325 | 413 | } |
|---|
| 326 | 414 | EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue); |
|---|
| 327 | 415 | |
|---|
| .. | .. |
|---|
| 383 | 471 | |
|---|
| 384 | 472 | vsk->trans = vvs; |
|---|
| 385 | 473 | vvs->vsk = vsk; |
|---|
| 386 | | - if (psk) { |
|---|
| 474 | + if (psk && psk->trans) { |
|---|
| 387 | 475 | struct virtio_vsock_sock *ptrans = psk->trans; |
|---|
| 388 | 476 | |
|---|
| 389 | | - vvs->buf_size = ptrans->buf_size; |
|---|
| 390 | | - vvs->buf_size_min = ptrans->buf_size_min; |
|---|
| 391 | | - vvs->buf_size_max = ptrans->buf_size_max; |
|---|
| 392 | 477 | vvs->peer_buf_alloc = ptrans->peer_buf_alloc; |
|---|
| 393 | | - } else { |
|---|
| 394 | | - vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE; |
|---|
| 395 | | - vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE; |
|---|
| 396 | | - vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE; |
|---|
| 397 | 478 | } |
|---|
| 398 | 479 | |
|---|
| 399 | | - vvs->buf_alloc = vvs->buf_size; |
|---|
| 480 | + if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE) |
|---|
| 481 | + vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE; |
|---|
| 482 | + |
|---|
| 483 | + vvs->buf_alloc = vsk->buffer_size; |
|---|
| 400 | 484 | |
|---|
| 401 | 485 | spin_lock_init(&vvs->rx_lock); |
|---|
| 402 | 486 | spin_lock_init(&vvs->tx_lock); |
|---|
| .. | .. |
|---|
| 406 | 490 | } |
|---|
| 407 | 491 | EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init); |
|---|
| 408 | 492 | |
|---|
| 409 | | -u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk) |
|---|
| 493 | +/* sk_lock held by the caller */ |
|---|
| 494 | +void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val) |
|---|
| 410 | 495 | { |
|---|
| 411 | 496 | struct virtio_vsock_sock *vvs = vsk->trans; |
|---|
| 412 | 497 | |
|---|
| 413 | | - return vvs->buf_size; |
|---|
| 498 | + if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE) |
|---|
| 499 | + *val = VIRTIO_VSOCK_MAX_BUF_SIZE; |
|---|
| 500 | + |
|---|
| 501 | + vvs->buf_alloc = *val; |
|---|
| 502 | + |
|---|
| 503 | + virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM, |
|---|
| 504 | + NULL); |
|---|
| 414 | 505 | } |
|---|
| 415 | | -EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size); |
|---|
| 416 | | - |
|---|
| 417 | | -u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk) |
|---|
| 418 | | -{ |
|---|
| 419 | | - struct virtio_vsock_sock *vvs = vsk->trans; |
|---|
| 420 | | - |
|---|
| 421 | | - return vvs->buf_size_min; |
|---|
| 422 | | -} |
|---|
| 423 | | -EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size); |
|---|
| 424 | | - |
|---|
| 425 | | -u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk) |
|---|
| 426 | | -{ |
|---|
| 427 | | - struct virtio_vsock_sock *vvs = vsk->trans; |
|---|
| 428 | | - |
|---|
| 429 | | - return vvs->buf_size_max; |
|---|
| 430 | | -} |
|---|
| 431 | | -EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size); |
|---|
| 432 | | - |
|---|
| 433 | | -void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val) |
|---|
| 434 | | -{ |
|---|
| 435 | | - struct virtio_vsock_sock *vvs = vsk->trans; |
|---|
| 436 | | - |
|---|
| 437 | | - if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) |
|---|
| 438 | | - val = VIRTIO_VSOCK_MAX_BUF_SIZE; |
|---|
| 439 | | - if (val < vvs->buf_size_min) |
|---|
| 440 | | - vvs->buf_size_min = val; |
|---|
| 441 | | - if (val > vvs->buf_size_max) |
|---|
| 442 | | - vvs->buf_size_max = val; |
|---|
| 443 | | - vvs->buf_size = val; |
|---|
| 444 | | - vvs->buf_alloc = val; |
|---|
| 445 | | -} |
|---|
| 446 | | -EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size); |
|---|
| 447 | | - |
|---|
| 448 | | -void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val) |
|---|
| 449 | | -{ |
|---|
| 450 | | - struct virtio_vsock_sock *vvs = vsk->trans; |
|---|
| 451 | | - |
|---|
| 452 | | - if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) |
|---|
| 453 | | - val = VIRTIO_VSOCK_MAX_BUF_SIZE; |
|---|
| 454 | | - if (val > vvs->buf_size) |
|---|
| 455 | | - vvs->buf_size = val; |
|---|
| 456 | | - vvs->buf_size_min = val; |
|---|
| 457 | | -} |
|---|
| 458 | | -EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size); |
|---|
| 459 | | - |
|---|
| 460 | | -void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val) |
|---|
| 461 | | -{ |
|---|
| 462 | | - struct virtio_vsock_sock *vvs = vsk->trans; |
|---|
| 463 | | - |
|---|
| 464 | | - if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) |
|---|
| 465 | | - val = VIRTIO_VSOCK_MAX_BUF_SIZE; |
|---|
| 466 | | - if (val < vvs->buf_size) |
|---|
| 467 | | - vvs->buf_size = val; |
|---|
| 468 | | - vvs->buf_size_max = val; |
|---|
| 469 | | -} |
|---|
| 470 | | -EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size); |
|---|
| 506 | +EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size); |
|---|
| 471 | 507 | |
|---|
| 472 | 508 | int |
|---|
| 473 | 509 | virtio_transport_notify_poll_in(struct vsock_sock *vsk, |
|---|
| .. | .. |
|---|
| 559 | 595 | |
|---|
| 560 | 596 | u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk) |
|---|
| 561 | 597 | { |
|---|
| 562 | | - struct virtio_vsock_sock *vvs = vsk->trans; |
|---|
| 563 | | - |
|---|
| 564 | | - return vvs->buf_size; |
|---|
| 598 | + return vsk->buffer_size; |
|---|
| 565 | 599 | } |
|---|
| 566 | 600 | EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat); |
|---|
| 567 | 601 | |
|---|
| .. | .. |
|---|
| 703 | 737 | return t->send_pkt(reply); |
|---|
| 704 | 738 | } |
|---|
| 705 | 739 | |
|---|
| 740 | +/* This function should be called with sk_lock held and SOCK_DONE set */ |
|---|
| 741 | +static void virtio_transport_remove_sock(struct vsock_sock *vsk) |
|---|
| 742 | +{ |
|---|
| 743 | + struct virtio_vsock_sock *vvs = vsk->trans; |
|---|
| 744 | + struct virtio_vsock_pkt *pkt, *tmp; |
|---|
| 745 | + |
|---|
| 746 | + /* We don't need to take rx_lock, as the socket is closing and we are |
|---|
| 747 | + * removing it. |
|---|
| 748 | + */ |
|---|
| 749 | + list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) { |
|---|
| 750 | + list_del(&pkt->list); |
|---|
| 751 | + virtio_transport_free_pkt(pkt); |
|---|
| 752 | + } |
|---|
| 753 | + |
|---|
| 754 | + vsock_remove_sock(vsk); |
|---|
| 755 | +} |
|---|
| 756 | + |
|---|
| 706 | 757 | static void virtio_transport_wait_close(struct sock *sk, long timeout) |
|---|
| 707 | 758 | { |
|---|
| 708 | 759 | if (timeout) { |
|---|
| .. | .. |
|---|
| 735 | 786 | (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) { |
|---|
| 736 | 787 | vsk->close_work_scheduled = false; |
|---|
| 737 | 788 | |
|---|
| 738 | | - vsock_remove_sock(vsk); |
|---|
| 789 | + virtio_transport_remove_sock(vsk); |
|---|
| 739 | 790 | |
|---|
| 740 | 791 | /* Release refcnt obtained when we scheduled the timeout */ |
|---|
| 741 | 792 | sock_put(sk); |
|---|
| .. | .. |
|---|
| 798 | 849 | |
|---|
| 799 | 850 | void virtio_transport_release(struct vsock_sock *vsk) |
|---|
| 800 | 851 | { |
|---|
| 801 | | - struct virtio_vsock_sock *vvs = vsk->trans; |
|---|
| 802 | | - struct virtio_vsock_pkt *pkt, *tmp; |
|---|
| 803 | 852 | struct sock *sk = &vsk->sk; |
|---|
| 804 | 853 | bool remove_sock = true; |
|---|
| 805 | 854 | |
|---|
| 806 | | - lock_sock_nested(sk, SINGLE_DEPTH_NESTING); |
|---|
| 807 | 855 | if (sk->sk_type == SOCK_STREAM) |
|---|
| 808 | 856 | remove_sock = virtio_transport_close(vsk); |
|---|
| 809 | 857 | |
|---|
| 810 | | - list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) { |
|---|
| 811 | | - list_del(&pkt->list); |
|---|
| 812 | | - virtio_transport_free_pkt(pkt); |
|---|
| 858 | + if (remove_sock) { |
|---|
| 859 | + sock_set_flag(sk, SOCK_DONE); |
|---|
| 860 | + virtio_transport_remove_sock(vsk); |
|---|
| 813 | 861 | } |
|---|
| 814 | | - release_sock(sk); |
|---|
| 815 | | - |
|---|
| 816 | | - if (remove_sock) |
|---|
| 817 | | - vsock_remove_sock(vsk); |
|---|
| 818 | 862 | } |
|---|
| 819 | 863 | EXPORT_SYMBOL_GPL(virtio_transport_release); |
|---|
| 820 | 864 | |
|---|
| .. | .. |
|---|
| 854 | 898 | return err; |
|---|
| 855 | 899 | } |
|---|
| 856 | 900 | |
|---|
| 901 | +static void |
|---|
| 902 | +virtio_transport_recv_enqueue(struct vsock_sock *vsk, |
|---|
| 903 | + struct virtio_vsock_pkt *pkt) |
|---|
| 904 | +{ |
|---|
| 905 | + struct virtio_vsock_sock *vvs = vsk->trans; |
|---|
| 906 | + bool can_enqueue, free_pkt = false; |
|---|
| 907 | + |
|---|
| 908 | + pkt->len = le32_to_cpu(pkt->hdr.len); |
|---|
| 909 | + pkt->off = 0; |
|---|
| 910 | + |
|---|
| 911 | + spin_lock_bh(&vvs->rx_lock); |
|---|
| 912 | + |
|---|
| 913 | + can_enqueue = virtio_transport_inc_rx_pkt(vvs, pkt); |
|---|
| 914 | + if (!can_enqueue) { |
|---|
| 915 | + free_pkt = true; |
|---|
| 916 | + goto out; |
|---|
| 917 | + } |
|---|
| 918 | + |
|---|
| 919 | + /* Try to copy small packets into the buffer of last packet queued, |
|---|
| 920 | + * to avoid wasting memory queueing the entire buffer with a small |
|---|
| 921 | + * payload. |
|---|
| 922 | + */ |
|---|
| 923 | + if (pkt->len <= GOOD_COPY_LEN && !list_empty(&vvs->rx_queue)) { |
|---|
| 924 | + struct virtio_vsock_pkt *last_pkt; |
|---|
| 925 | + |
|---|
| 926 | + last_pkt = list_last_entry(&vvs->rx_queue, |
|---|
| 927 | + struct virtio_vsock_pkt, list); |
|---|
| 928 | + |
|---|
| 929 | + /* If there is space in the last packet queued, we copy the |
|---|
| 930 | + * new packet in its buffer. |
|---|
| 931 | + */ |
|---|
| 932 | + if (pkt->len <= last_pkt->buf_len - last_pkt->len) { |
|---|
| 933 | + memcpy(last_pkt->buf + last_pkt->len, pkt->buf, |
|---|
| 934 | + pkt->len); |
|---|
| 935 | + last_pkt->len += pkt->len; |
|---|
| 936 | + free_pkt = true; |
|---|
| 937 | + goto out; |
|---|
| 938 | + } |
|---|
| 939 | + } |
|---|
| 940 | + |
|---|
| 941 | + list_add_tail(&pkt->list, &vvs->rx_queue); |
|---|
| 942 | + |
|---|
| 943 | +out: |
|---|
| 944 | + spin_unlock_bh(&vvs->rx_lock); |
|---|
| 945 | + if (free_pkt) |
|---|
| 946 | + virtio_transport_free_pkt(pkt); |
|---|
| 947 | +} |
|---|
| 948 | + |
|---|
| 857 | 949 | static int |
|---|
| 858 | 950 | virtio_transport_recv_connected(struct sock *sk, |
|---|
| 859 | 951 | struct virtio_vsock_pkt *pkt) |
|---|
| 860 | 952 | { |
|---|
| 861 | 953 | struct vsock_sock *vsk = vsock_sk(sk); |
|---|
| 862 | | - struct virtio_vsock_sock *vvs = vsk->trans; |
|---|
| 863 | 954 | int err = 0; |
|---|
| 864 | 955 | |
|---|
| 865 | 956 | switch (le16_to_cpu(pkt->hdr.op)) { |
|---|
| 866 | 957 | case VIRTIO_VSOCK_OP_RW: |
|---|
| 867 | | - pkt->len = le32_to_cpu(pkt->hdr.len); |
|---|
| 868 | | - pkt->off = 0; |
|---|
| 869 | | - |
|---|
| 870 | | - spin_lock_bh(&vvs->rx_lock); |
|---|
| 871 | | - virtio_transport_inc_rx_pkt(vvs, pkt); |
|---|
| 872 | | - list_add_tail(&pkt->list, &vvs->rx_queue); |
|---|
| 873 | | - spin_unlock_bh(&vvs->rx_lock); |
|---|
| 874 | | - |
|---|
| 958 | + virtio_transport_recv_enqueue(vsk, pkt); |
|---|
| 875 | 959 | sk->sk_data_ready(sk); |
|---|
| 876 | 960 | return err; |
|---|
| 877 | 961 | case VIRTIO_VSOCK_OP_CREDIT_UPDATE: |
|---|
| .. | .. |
|---|
| 930 | 1014 | return virtio_transport_send_pkt_info(vsk, &info); |
|---|
| 931 | 1015 | } |
|---|
| 932 | 1016 | |
|---|
| 1017 | +static bool virtio_transport_space_update(struct sock *sk, |
|---|
| 1018 | + struct virtio_vsock_pkt *pkt) |
|---|
| 1019 | +{ |
|---|
| 1020 | + struct vsock_sock *vsk = vsock_sk(sk); |
|---|
| 1021 | + struct virtio_vsock_sock *vvs = vsk->trans; |
|---|
| 1022 | + bool space_available; |
|---|
| 1023 | + |
|---|
| 1024 | + /* Listener sockets are not associated with any transport, so we are |
|---|
| 1025 | + * not able to take the state to see if there is space available in the |
|---|
| 1026 | + * remote peer, but since they are only used to receive requests, we |
|---|
| 1027 | + * can assume that there is always space available in the other peer. |
|---|
| 1028 | + */ |
|---|
| 1029 | + if (!vvs) |
|---|
| 1030 | + return true; |
|---|
| 1031 | + |
|---|
| 1032 | + /* buf_alloc and fwd_cnt is always included in the hdr */ |
|---|
| 1033 | + spin_lock_bh(&vvs->tx_lock); |
|---|
| 1034 | + vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc); |
|---|
| 1035 | + vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt); |
|---|
| 1036 | + space_available = virtio_transport_has_space(vsk); |
|---|
| 1037 | + spin_unlock_bh(&vvs->tx_lock); |
|---|
| 1038 | + return space_available; |
|---|
| 1039 | +} |
|---|
| 1040 | + |
|---|
| 933 | 1041 | /* Handle server socket */ |
|---|
| 934 | 1042 | static int |
|---|
| 935 | | -virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) |
|---|
| 1043 | +virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt, |
|---|
| 1044 | + struct virtio_transport *t) |
|---|
| 936 | 1045 | { |
|---|
| 937 | 1046 | struct vsock_sock *vsk = vsock_sk(sk); |
|---|
| 938 | 1047 | struct vsock_sock *vchild; |
|---|
| 939 | 1048 | struct sock *child; |
|---|
| 1049 | + int ret; |
|---|
| 940 | 1050 | |
|---|
| 941 | 1051 | if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) { |
|---|
| 942 | | - virtio_transport_reset(vsk, pkt); |
|---|
| 1052 | + virtio_transport_reset_no_sock(t, pkt); |
|---|
| 943 | 1053 | return -EINVAL; |
|---|
| 944 | 1054 | } |
|---|
| 945 | 1055 | |
|---|
| 946 | 1056 | if (sk_acceptq_is_full(sk)) { |
|---|
| 947 | | - virtio_transport_reset(vsk, pkt); |
|---|
| 1057 | + virtio_transport_reset_no_sock(t, pkt); |
|---|
| 948 | 1058 | return -ENOMEM; |
|---|
| 949 | 1059 | } |
|---|
| 950 | 1060 | |
|---|
| 951 | | - child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL, |
|---|
| 952 | | - sk->sk_type, 0); |
|---|
| 1061 | + child = vsock_create_connected(sk); |
|---|
| 953 | 1062 | if (!child) { |
|---|
| 954 | | - virtio_transport_reset(vsk, pkt); |
|---|
| 1063 | + virtio_transport_reset_no_sock(t, pkt); |
|---|
| 955 | 1064 | return -ENOMEM; |
|---|
| 956 | 1065 | } |
|---|
| 957 | 1066 | |
|---|
| 958 | | - sk->sk_ack_backlog++; |
|---|
| 1067 | + sk_acceptq_added(sk); |
|---|
| 959 | 1068 | |
|---|
| 960 | 1069 | lock_sock_nested(child, SINGLE_DEPTH_NESTING); |
|---|
| 961 | 1070 | |
|---|
| .. | .. |
|---|
| 967 | 1076 | vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid), |
|---|
| 968 | 1077 | le32_to_cpu(pkt->hdr.src_port)); |
|---|
| 969 | 1078 | |
|---|
| 1079 | + ret = vsock_assign_transport(vchild, vsk); |
|---|
| 1080 | + /* Transport assigned (looking at remote_addr) must be the same |
|---|
| 1081 | + * where we received the request. |
|---|
| 1082 | + */ |
|---|
| 1083 | + if (ret || vchild->transport != &t->transport) { |
|---|
| 1084 | + release_sock(child); |
|---|
| 1085 | + virtio_transport_reset_no_sock(t, pkt); |
|---|
| 1086 | + sock_put(child); |
|---|
| 1087 | + return ret; |
|---|
| 1088 | + } |
|---|
| 1089 | + |
|---|
| 1090 | + if (virtio_transport_space_update(child, pkt)) |
|---|
| 1091 | + child->sk_write_space(child); |
|---|
| 1092 | + |
|---|
| 970 | 1093 | vsock_insert_connected(vchild); |
|---|
| 971 | 1094 | vsock_enqueue_accept(sk, child); |
|---|
| 972 | 1095 | virtio_transport_send_response(vchild, pkt); |
|---|
| .. | .. |
|---|
| 975 | 1098 | |
|---|
| 976 | 1099 | sk->sk_data_ready(sk); |
|---|
| 977 | 1100 | return 0; |
|---|
| 978 | | -} |
|---|
| 979 | | - |
|---|
| 980 | | -static bool virtio_transport_space_update(struct sock *sk, |
|---|
| 981 | | - struct virtio_vsock_pkt *pkt) |
|---|
| 982 | | -{ |
|---|
| 983 | | - struct vsock_sock *vsk = vsock_sk(sk); |
|---|
| 984 | | - struct virtio_vsock_sock *vvs = vsk->trans; |
|---|
| 985 | | - bool space_available; |
|---|
| 986 | | - |
|---|
| 987 | | - /* buf_alloc and fwd_cnt is always included in the hdr */ |
|---|
| 988 | | - spin_lock_bh(&vvs->tx_lock); |
|---|
| 989 | | - vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc); |
|---|
| 990 | | - vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt); |
|---|
| 991 | | - space_available = virtio_transport_has_space(vsk); |
|---|
| 992 | | - spin_unlock_bh(&vvs->tx_lock); |
|---|
| 993 | | - return space_available; |
|---|
| 994 | 1101 | } |
|---|
| 995 | 1102 | |
|---|
| 996 | 1103 | /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex |
|---|
| .. | .. |
|---|
| 1039 | 1146 | |
|---|
| 1040 | 1147 | lock_sock(sk); |
|---|
| 1041 | 1148 | |
|---|
| 1149 | + /* Check if sk has been closed before lock_sock */ |
|---|
| 1150 | + if (sock_flag(sk, SOCK_DONE)) { |
|---|
| 1151 | + (void)virtio_transport_reset_no_sock(t, pkt); |
|---|
| 1152 | + release_sock(sk); |
|---|
| 1153 | + sock_put(sk); |
|---|
| 1154 | + goto free_pkt; |
|---|
| 1155 | + } |
|---|
| 1156 | + |
|---|
| 1042 | 1157 | space_available = virtio_transport_space_update(sk, pkt); |
|---|
| 1043 | 1158 | |
|---|
| 1044 | 1159 | /* Update CID in case it has changed after a transport reset event */ |
|---|
| 1045 | | - vsk->local_addr.svm_cid = dst.svm_cid; |
|---|
| 1160 | + if (vsk->local_addr.svm_cid != VMADDR_CID_ANY) |
|---|
| 1161 | + vsk->local_addr.svm_cid = dst.svm_cid; |
|---|
| 1046 | 1162 | |
|---|
| 1047 | 1163 | if (space_available) |
|---|
| 1048 | 1164 | sk->sk_write_space(sk); |
|---|
| 1049 | 1165 | |
|---|
| 1050 | 1166 | switch (sk->sk_state) { |
|---|
| 1051 | 1167 | case TCP_LISTEN: |
|---|
| 1052 | | - virtio_transport_recv_listen(sk, pkt); |
|---|
| 1168 | + virtio_transport_recv_listen(sk, pkt, t); |
|---|
| 1053 | 1169 | virtio_transport_free_pkt(pkt); |
|---|
| 1054 | 1170 | break; |
|---|
| 1055 | 1171 | case TCP_SYN_SENT: |
|---|
| .. | .. |
|---|
| 1068 | 1184 | virtio_transport_free_pkt(pkt); |
|---|
| 1069 | 1185 | break; |
|---|
| 1070 | 1186 | } |
|---|
| 1187 | + |
|---|
| 1071 | 1188 | release_sock(sk); |
|---|
| 1072 | 1189 | |
|---|
| 1073 | 1190 | /* Release refcnt obtained when we fetched this socket out of the |
|---|
| .. | .. |
|---|
| 1083 | 1200 | |
|---|
| 1084 | 1201 | void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt) |
|---|
| 1085 | 1202 | { |
|---|
| 1086 | | - kfree(pkt->buf); |
|---|
| 1203 | + kvfree(pkt->buf); |
|---|
| 1087 | 1204 | kfree(pkt); |
|---|
| 1088 | 1205 | } |
|---|
| 1089 | 1206 | EXPORT_SYMBOL_GPL(virtio_transport_free_pkt); |
|---|