.. | .. |
---|
| 1 | +// SPDX-License-Identifier: GPL-2.0-only |
---|
1 | 2 | /* |
---|
2 | 3 | * common code for virtio vsock |
---|
3 | 4 | * |
---|
4 | 5 | * Copyright (C) 2013-2015 Red Hat, Inc. |
---|
5 | 6 | * Author: Asias He <asias@redhat.com> |
---|
6 | 7 | * Stefan Hajnoczi <stefanha@redhat.com> |
---|
7 | | - * |
---|
8 | | - * This work is licensed under the terms of the GNU GPL, version 2. |
---|
9 | 8 | */ |
---|
10 | 9 | #include <linux/spinlock.h> |
---|
11 | 10 | #include <linux/module.h> |
---|
12 | 11 | #include <linux/sched/signal.h> |
---|
13 | 12 | #include <linux/ctype.h> |
---|
14 | 13 | #include <linux/list.h> |
---|
15 | | -#include <linux/virtio.h> |
---|
16 | | -#include <linux/virtio_ids.h> |
---|
17 | | -#include <linux/virtio_config.h> |
---|
18 | 14 | #include <linux/virtio_vsock.h> |
---|
19 | 15 | #include <uapi/linux/vsockmon.h> |
---|
20 | 16 | |
---|
.. | .. |
---|
27 | 23 | /* How long to wait for graceful shutdown of a connection */ |
---|
28 | 24 | #define VSOCK_CLOSE_TIMEOUT (8 * HZ) |
---|
29 | 25 | |
---|
| 26 | +/* Threshold for detecting small packets to copy */ |
---|
| 27 | +#define GOOD_COPY_LEN 128 |
---|
| 28 | + |
---|
30 | 29 | uint virtio_transport_max_vsock_pkt_buf_size = 64 * 1024; |
---|
31 | 30 | module_param(virtio_transport_max_vsock_pkt_buf_size, uint, 0444); |
---|
32 | 31 | EXPORT_SYMBOL_GPL(virtio_transport_max_vsock_pkt_buf_size); |
---|
33 | 32 | |
---|
34 | | -static const struct virtio_transport *virtio_transport_get_ops(void) |
---|
| 33 | +static const struct virtio_transport * |
---|
| 34 | +virtio_transport_get_ops(struct vsock_sock *vsk) |
---|
35 | 35 | { |
---|
36 | | - const struct vsock_transport *t = vsock_core_get_transport(); |
---|
| 36 | + const struct vsock_transport *t = vsock_core_get_transport(vsk); |
---|
| 37 | + |
---|
| 38 | + if (WARN_ON(!t)) |
---|
| 39 | + return NULL; |
---|
37 | 40 | |
---|
38 | 41 | return container_of(t, struct virtio_transport, transport); |
---|
39 | 42 | } |
---|
.. | .. |
---|
69 | 72 | pkt->buf = kmalloc(len, GFP_KERNEL); |
---|
70 | 73 | if (!pkt->buf) |
---|
71 | 74 | goto out_pkt; |
---|
| 75 | + |
---|
| 76 | + pkt->buf_len = len; |
---|
| 77 | + |
---|
72 | 78 | err = memcpy_from_msg(pkt->buf, info->msg, len); |
---|
73 | 79 | if (err) |
---|
74 | 80 | goto out; |
---|
.. | .. |
---|
155 | 161 | |
---|
156 | 162 | void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt) |
---|
157 | 163 | { |
---|
| 164 | + if (pkt->tap_delivered) |
---|
| 165 | + return; |
---|
| 166 | + |
---|
158 | 167 | vsock_deliver_tap(virtio_transport_build_skb, pkt); |
---|
| 168 | + pkt->tap_delivered = true; |
---|
159 | 169 | } |
---|
160 | 170 | EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt); |
---|
161 | 171 | |
---|
| 172 | +/* This function can only be used on connecting/connected sockets, |
---|
| 173 | + * since a socket assigned to a transport is required. |
---|
| 174 | + * |
---|
| 175 | + * Do not use on listener sockets! |
---|
| 176 | + */ |
---|
162 | 177 | static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, |
---|
163 | 178 | struct virtio_vsock_pkt_info *info) |
---|
164 | 179 | { |
---|
165 | 180 | u32 src_cid, src_port, dst_cid, dst_port; |
---|
| 181 | + const struct virtio_transport *t_ops; |
---|
166 | 182 | struct virtio_vsock_sock *vvs; |
---|
167 | 183 | struct virtio_vsock_pkt *pkt; |
---|
168 | 184 | u32 pkt_len = info->pkt_len; |
---|
169 | 185 | |
---|
170 | | - src_cid = vm_sockets_get_local_cid(); |
---|
| 186 | + t_ops = virtio_transport_get_ops(vsk); |
---|
| 187 | + if (unlikely(!t_ops)) |
---|
| 188 | + return -EFAULT; |
---|
| 189 | + |
---|
| 190 | + src_cid = t_ops->transport.get_local_cid(); |
---|
171 | 191 | src_port = vsk->local_addr.svm_port; |
---|
172 | 192 | if (!info->remote_cid) { |
---|
173 | 193 | dst_cid = vsk->remote_addr.svm_cid; |
---|
.. | .. |
---|
180 | 200 | vvs = vsk->trans; |
---|
181 | 201 | |
---|
182 | 202 | /* we can send less than pkt_len bytes */ |
---|
183 | | - if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE) |
---|
184 | | - pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE; |
---|
| 203 | + if (pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) |
---|
| 204 | + pkt_len = VIRTIO_VSOCK_MAX_PKT_BUF_SIZE; |
---|
185 | 205 | |
---|
186 | 206 | /* virtio_transport_get_credit might return less than pkt_len credit */ |
---|
187 | 207 | pkt_len = virtio_transport_get_credit(vvs, pkt_len); |
---|
.. | .. |
---|
200 | 220 | |
---|
201 | 221 | virtio_transport_inc_tx_pkt(vvs, pkt); |
---|
202 | 222 | |
---|
203 | | - return virtio_transport_get_ops()->send_pkt(pkt); |
---|
| 223 | + return t_ops->send_pkt(pkt); |
---|
204 | 224 | } |
---|
205 | 225 | |
---|
206 | | -static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, |
---|
| 226 | +static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, |
---|
207 | 227 | struct virtio_vsock_pkt *pkt) |
---|
208 | 228 | { |
---|
| 229 | + if (vvs->rx_bytes + pkt->len > vvs->buf_alloc) |
---|
| 230 | + return false; |
---|
| 231 | + |
---|
209 | 232 | vvs->rx_bytes += pkt->len; |
---|
| 233 | + return true; |
---|
210 | 234 | } |
---|
211 | 235 | |
---|
212 | 236 | static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs, |
---|
.. | .. |
---|
218 | 242 | |
---|
219 | 243 | void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt) |
---|
220 | 244 | { |
---|
221 | | - spin_lock_bh(&vvs->tx_lock); |
---|
| 245 | + spin_lock_bh(&vvs->rx_lock); |
---|
| 246 | + vvs->last_fwd_cnt = vvs->fwd_cnt; |
---|
222 | 247 | pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt); |
---|
223 | 248 | pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc); |
---|
224 | | - spin_unlock_bh(&vvs->tx_lock); |
---|
| 249 | + spin_unlock_bh(&vvs->rx_lock); |
---|
225 | 250 | } |
---|
226 | 251 | EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt); |
---|
227 | 252 | |
---|
.. | .. |
---|
262 | 287 | } |
---|
263 | 288 | |
---|
264 | 289 | static ssize_t |
---|
| 290 | +virtio_transport_stream_do_peek(struct vsock_sock *vsk, |
---|
| 291 | + struct msghdr *msg, |
---|
| 292 | + size_t len) |
---|
| 293 | +{ |
---|
| 294 | + struct virtio_vsock_sock *vvs = vsk->trans; |
---|
| 295 | + struct virtio_vsock_pkt *pkt; |
---|
| 296 | + size_t bytes, total = 0, off; |
---|
| 297 | + int err = -EFAULT; |
---|
| 298 | + |
---|
| 299 | + spin_lock_bh(&vvs->rx_lock); |
---|
| 300 | + |
---|
| 301 | + list_for_each_entry(pkt, &vvs->rx_queue, list) { |
---|
| 302 | + off = pkt->off; |
---|
| 303 | + |
---|
| 304 | + if (total == len) |
---|
| 305 | + break; |
---|
| 306 | + |
---|
| 307 | + while (total < len && off < pkt->len) { |
---|
| 308 | + bytes = len - total; |
---|
| 309 | + if (bytes > pkt->len - off) |
---|
| 310 | + bytes = pkt->len - off; |
---|
| 311 | + |
---|
| 312 | + /* sk_lock is held by caller so no one else can dequeue. |
---|
| 313 | + * Unlock rx_lock since memcpy_to_msg() may sleep. |
---|
| 314 | + */ |
---|
| 315 | + spin_unlock_bh(&vvs->rx_lock); |
---|
| 316 | + |
---|
| 317 | + err = memcpy_to_msg(msg, pkt->buf + off, bytes); |
---|
| 318 | + if (err) |
---|
| 319 | + goto out; |
---|
| 320 | + |
---|
| 321 | + spin_lock_bh(&vvs->rx_lock); |
---|
| 322 | + |
---|
| 323 | + total += bytes; |
---|
| 324 | + off += bytes; |
---|
| 325 | + } |
---|
| 326 | + } |
---|
| 327 | + |
---|
| 328 | + spin_unlock_bh(&vvs->rx_lock); |
---|
| 329 | + |
---|
| 330 | + return total; |
---|
| 331 | + |
---|
| 332 | +out: |
---|
| 333 | + if (total) |
---|
| 334 | + err = total; |
---|
| 335 | + return err; |
---|
| 336 | +} |
---|
| 337 | + |
---|
| 338 | +static ssize_t |
---|
265 | 339 | virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, |
---|
266 | 340 | struct msghdr *msg, |
---|
267 | 341 | size_t len) |
---|
.. | .. |
---|
269 | 343 | struct virtio_vsock_sock *vvs = vsk->trans; |
---|
270 | 344 | struct virtio_vsock_pkt *pkt; |
---|
271 | 345 | size_t bytes, total = 0; |
---|
| 346 | + u32 free_space; |
---|
272 | 347 | int err = -EFAULT; |
---|
273 | 348 | |
---|
274 | 349 | spin_lock_bh(&vvs->rx_lock); |
---|
.. | .. |
---|
299 | 374 | virtio_transport_free_pkt(pkt); |
---|
300 | 375 | } |
---|
301 | 376 | } |
---|
| 377 | + |
---|
| 378 | + free_space = vvs->buf_alloc - (vvs->fwd_cnt - vvs->last_fwd_cnt); |
---|
| 379 | + |
---|
302 | 380 | spin_unlock_bh(&vvs->rx_lock); |
---|
303 | 381 | |
---|
304 | | - /* Send a credit pkt to peer */ |
---|
305 | | - virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM, |
---|
306 | | - NULL); |
---|
| 382 | + /* To reduce the number of credit update messages, |
---|
| 383 | + * don't update credits as long as lots of space is available. |
---|
| 384 | + * Note: the limit chosen here is arbitrary. Setting the limit |
---|
| 385 | + * too high causes extra messages. Too low causes transmitter |
---|
| 386 | + * stalls. As stalls are in theory more expensive than extra |
---|
| 387 | + * messages, we set the limit to a high value. TODO: experiment |
---|
| 388 | + * with different values. |
---|
| 389 | + */ |
---|
| 390 | + if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) { |
---|
| 391 | + virtio_transport_send_credit_update(vsk, |
---|
| 392 | + VIRTIO_VSOCK_TYPE_STREAM, |
---|
| 393 | + NULL); |
---|
| 394 | + } |
---|
307 | 395 | |
---|
308 | 396 | return total; |
---|
309 | 397 | |
---|
.. | .. |
---|
319 | 407 | size_t len, int flags) |
---|
320 | 408 | { |
---|
321 | 409 | if (flags & MSG_PEEK) |
---|
322 | | - return -EOPNOTSUPP; |
---|
323 | | - |
---|
324 | | - return virtio_transport_stream_do_dequeue(vsk, msg, len); |
---|
| 410 | + return virtio_transport_stream_do_peek(vsk, msg, len); |
---|
| 411 | + else |
---|
| 412 | + return virtio_transport_stream_do_dequeue(vsk, msg, len); |
---|
325 | 413 | } |
---|
326 | 414 | EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue); |
---|
327 | 415 | |
---|
.. | .. |
---|
383 | 471 | |
---|
384 | 472 | vsk->trans = vvs; |
---|
385 | 473 | vvs->vsk = vsk; |
---|
386 | | - if (psk) { |
---|
| 474 | + if (psk && psk->trans) { |
---|
387 | 475 | struct virtio_vsock_sock *ptrans = psk->trans; |
---|
388 | 476 | |
---|
389 | | - vvs->buf_size = ptrans->buf_size; |
---|
390 | | - vvs->buf_size_min = ptrans->buf_size_min; |
---|
391 | | - vvs->buf_size_max = ptrans->buf_size_max; |
---|
392 | 477 | vvs->peer_buf_alloc = ptrans->peer_buf_alloc; |
---|
393 | | - } else { |
---|
394 | | - vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE; |
---|
395 | | - vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE; |
---|
396 | | - vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE; |
---|
397 | 478 | } |
---|
398 | 479 | |
---|
399 | | - vvs->buf_alloc = vvs->buf_size; |
---|
| 480 | + if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE) |
---|
| 481 | + vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE; |
---|
| 482 | + |
---|
| 483 | + vvs->buf_alloc = vsk->buffer_size; |
---|
400 | 484 | |
---|
401 | 485 | spin_lock_init(&vvs->rx_lock); |
---|
402 | 486 | spin_lock_init(&vvs->tx_lock); |
---|
.. | .. |
---|
406 | 490 | } |
---|
407 | 491 | EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init); |
---|
408 | 492 | |
---|
409 | | -u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk) |
---|
| 493 | +/* sk_lock held by the caller */ |
---|
| 494 | +void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val) |
---|
410 | 495 | { |
---|
411 | 496 | struct virtio_vsock_sock *vvs = vsk->trans; |
---|
412 | 497 | |
---|
413 | | - return vvs->buf_size; |
---|
| 498 | + if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE) |
---|
| 499 | + *val = VIRTIO_VSOCK_MAX_BUF_SIZE; |
---|
| 500 | + |
---|
| 501 | + vvs->buf_alloc = *val; |
---|
| 502 | + |
---|
| 503 | + virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM, |
---|
| 504 | + NULL); |
---|
414 | 505 | } |
---|
415 | | -EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size); |
---|
416 | | - |
---|
417 | | -u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk) |
---|
418 | | -{ |
---|
419 | | - struct virtio_vsock_sock *vvs = vsk->trans; |
---|
420 | | - |
---|
421 | | - return vvs->buf_size_min; |
---|
422 | | -} |
---|
423 | | -EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size); |
---|
424 | | - |
---|
425 | | -u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk) |
---|
426 | | -{ |
---|
427 | | - struct virtio_vsock_sock *vvs = vsk->trans; |
---|
428 | | - |
---|
429 | | - return vvs->buf_size_max; |
---|
430 | | -} |
---|
431 | | -EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size); |
---|
432 | | - |
---|
433 | | -void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val) |
---|
434 | | -{ |
---|
435 | | - struct virtio_vsock_sock *vvs = vsk->trans; |
---|
436 | | - |
---|
437 | | - if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) |
---|
438 | | - val = VIRTIO_VSOCK_MAX_BUF_SIZE; |
---|
439 | | - if (val < vvs->buf_size_min) |
---|
440 | | - vvs->buf_size_min = val; |
---|
441 | | - if (val > vvs->buf_size_max) |
---|
442 | | - vvs->buf_size_max = val; |
---|
443 | | - vvs->buf_size = val; |
---|
444 | | - vvs->buf_alloc = val; |
---|
445 | | -} |
---|
446 | | -EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size); |
---|
447 | | - |
---|
448 | | -void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val) |
---|
449 | | -{ |
---|
450 | | - struct virtio_vsock_sock *vvs = vsk->trans; |
---|
451 | | - |
---|
452 | | - if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) |
---|
453 | | - val = VIRTIO_VSOCK_MAX_BUF_SIZE; |
---|
454 | | - if (val > vvs->buf_size) |
---|
455 | | - vvs->buf_size = val; |
---|
456 | | - vvs->buf_size_min = val; |
---|
457 | | -} |
---|
458 | | -EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size); |
---|
459 | | - |
---|
460 | | -void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val) |
---|
461 | | -{ |
---|
462 | | - struct virtio_vsock_sock *vvs = vsk->trans; |
---|
463 | | - |
---|
464 | | - if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) |
---|
465 | | - val = VIRTIO_VSOCK_MAX_BUF_SIZE; |
---|
466 | | - if (val < vvs->buf_size) |
---|
467 | | - vvs->buf_size = val; |
---|
468 | | - vvs->buf_size_max = val; |
---|
469 | | -} |
---|
470 | | -EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size); |
---|
| 506 | +EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size); |
---|
471 | 507 | |
---|
472 | 508 | int |
---|
473 | 509 | virtio_transport_notify_poll_in(struct vsock_sock *vsk, |
---|
.. | .. |
---|
559 | 595 | |
---|
560 | 596 | u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk) |
---|
561 | 597 | { |
---|
562 | | - struct virtio_vsock_sock *vvs = vsk->trans; |
---|
563 | | - |
---|
564 | | - return vvs->buf_size; |
---|
| 598 | + return vsk->buffer_size; |
---|
565 | 599 | } |
---|
566 | 600 | EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat); |
---|
567 | 601 | |
---|
.. | .. |
---|
703 | 737 | return t->send_pkt(reply); |
---|
704 | 738 | } |
---|
705 | 739 | |
---|
| 740 | +/* This function should be called with sk_lock held and SOCK_DONE set */ |
---|
| 741 | +static void virtio_transport_remove_sock(struct vsock_sock *vsk) |
---|
| 742 | +{ |
---|
| 743 | + struct virtio_vsock_sock *vvs = vsk->trans; |
---|
| 744 | + struct virtio_vsock_pkt *pkt, *tmp; |
---|
| 745 | + |
---|
| 746 | + /* We don't need to take rx_lock, as the socket is closing and we are |
---|
| 747 | + * removing it. |
---|
| 748 | + */ |
---|
| 749 | + list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) { |
---|
| 750 | + list_del(&pkt->list); |
---|
| 751 | + virtio_transport_free_pkt(pkt); |
---|
| 752 | + } |
---|
| 753 | + |
---|
| 754 | + vsock_remove_sock(vsk); |
---|
| 755 | +} |
---|
| 756 | + |
---|
706 | 757 | static void virtio_transport_wait_close(struct sock *sk, long timeout) |
---|
707 | 758 | { |
---|
708 | 759 | if (timeout) { |
---|
.. | .. |
---|
735 | 786 | (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) { |
---|
736 | 787 | vsk->close_work_scheduled = false; |
---|
737 | 788 | |
---|
738 | | - vsock_remove_sock(vsk); |
---|
| 789 | + virtio_transport_remove_sock(vsk); |
---|
739 | 790 | |
---|
740 | 791 | /* Release refcnt obtained when we scheduled the timeout */ |
---|
741 | 792 | sock_put(sk); |
---|
.. | .. |
---|
798 | 849 | |
---|
799 | 850 | void virtio_transport_release(struct vsock_sock *vsk) |
---|
800 | 851 | { |
---|
801 | | - struct virtio_vsock_sock *vvs = vsk->trans; |
---|
802 | | - struct virtio_vsock_pkt *pkt, *tmp; |
---|
803 | 852 | struct sock *sk = &vsk->sk; |
---|
804 | 853 | bool remove_sock = true; |
---|
805 | 854 | |
---|
806 | | - lock_sock_nested(sk, SINGLE_DEPTH_NESTING); |
---|
807 | 855 | if (sk->sk_type == SOCK_STREAM) |
---|
808 | 856 | remove_sock = virtio_transport_close(vsk); |
---|
809 | 857 | |
---|
810 | | - list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) { |
---|
811 | | - list_del(&pkt->list); |
---|
812 | | - virtio_transport_free_pkt(pkt); |
---|
| 858 | + if (remove_sock) { |
---|
| 859 | + sock_set_flag(sk, SOCK_DONE); |
---|
| 860 | + virtio_transport_remove_sock(vsk); |
---|
813 | 861 | } |
---|
814 | | - release_sock(sk); |
---|
815 | | - |
---|
816 | | - if (remove_sock) |
---|
817 | | - vsock_remove_sock(vsk); |
---|
818 | 862 | } |
---|
819 | 863 | EXPORT_SYMBOL_GPL(virtio_transport_release); |
---|
820 | 864 | |
---|
.. | .. |
---|
854 | 898 | return err; |
---|
855 | 899 | } |
---|
856 | 900 | |
---|
| 901 | +static void |
---|
| 902 | +virtio_transport_recv_enqueue(struct vsock_sock *vsk, |
---|
| 903 | + struct virtio_vsock_pkt *pkt) |
---|
| 904 | +{ |
---|
| 905 | + struct virtio_vsock_sock *vvs = vsk->trans; |
---|
| 906 | + bool can_enqueue, free_pkt = false; |
---|
| 907 | + |
---|
| 908 | + pkt->len = le32_to_cpu(pkt->hdr.len); |
---|
| 909 | + pkt->off = 0; |
---|
| 910 | + |
---|
| 911 | + spin_lock_bh(&vvs->rx_lock); |
---|
| 912 | + |
---|
| 913 | + can_enqueue = virtio_transport_inc_rx_pkt(vvs, pkt); |
---|
| 914 | + if (!can_enqueue) { |
---|
| 915 | + free_pkt = true; |
---|
| 916 | + goto out; |
---|
| 917 | + } |
---|
| 918 | + |
---|
| 919 | + /* Try to copy small packets into the buffer of last packet queued, |
---|
| 920 | + * to avoid wasting memory queueing the entire buffer with a small |
---|
| 921 | + * payload. |
---|
| 922 | + */ |
---|
| 923 | + if (pkt->len <= GOOD_COPY_LEN && !list_empty(&vvs->rx_queue)) { |
---|
| 924 | + struct virtio_vsock_pkt *last_pkt; |
---|
| 925 | + |
---|
| 926 | + last_pkt = list_last_entry(&vvs->rx_queue, |
---|
| 927 | + struct virtio_vsock_pkt, list); |
---|
| 928 | + |
---|
| 929 | + /* If there is space in the last packet queued, we copy the |
---|
| 930 | + * new packet in its buffer. |
---|
| 931 | + */ |
---|
| 932 | + if (pkt->len <= last_pkt->buf_len - last_pkt->len) { |
---|
| 933 | + memcpy(last_pkt->buf + last_pkt->len, pkt->buf, |
---|
| 934 | + pkt->len); |
---|
| 935 | + last_pkt->len += pkt->len; |
---|
| 936 | + free_pkt = true; |
---|
| 937 | + goto out; |
---|
| 938 | + } |
---|
| 939 | + } |
---|
| 940 | + |
---|
| 941 | + list_add_tail(&pkt->list, &vvs->rx_queue); |
---|
| 942 | + |
---|
| 943 | +out: |
---|
| 944 | + spin_unlock_bh(&vvs->rx_lock); |
---|
| 945 | + if (free_pkt) |
---|
| 946 | + virtio_transport_free_pkt(pkt); |
---|
| 947 | +} |
---|
| 948 | + |
---|
857 | 949 | static int |
---|
858 | 950 | virtio_transport_recv_connected(struct sock *sk, |
---|
859 | 951 | struct virtio_vsock_pkt *pkt) |
---|
860 | 952 | { |
---|
861 | 953 | struct vsock_sock *vsk = vsock_sk(sk); |
---|
862 | | - struct virtio_vsock_sock *vvs = vsk->trans; |
---|
863 | 954 | int err = 0; |
---|
864 | 955 | |
---|
865 | 956 | switch (le16_to_cpu(pkt->hdr.op)) { |
---|
866 | 957 | case VIRTIO_VSOCK_OP_RW: |
---|
867 | | - pkt->len = le32_to_cpu(pkt->hdr.len); |
---|
868 | | - pkt->off = 0; |
---|
869 | | - |
---|
870 | | - spin_lock_bh(&vvs->rx_lock); |
---|
871 | | - virtio_transport_inc_rx_pkt(vvs, pkt); |
---|
872 | | - list_add_tail(&pkt->list, &vvs->rx_queue); |
---|
873 | | - spin_unlock_bh(&vvs->rx_lock); |
---|
874 | | - |
---|
| 958 | + virtio_transport_recv_enqueue(vsk, pkt); |
---|
875 | 959 | sk->sk_data_ready(sk); |
---|
876 | 960 | return err; |
---|
877 | 961 | case VIRTIO_VSOCK_OP_CREDIT_UPDATE: |
---|
.. | .. |
---|
930 | 1014 | return virtio_transport_send_pkt_info(vsk, &info); |
---|
931 | 1015 | } |
---|
932 | 1016 | |
---|
| 1017 | +static bool virtio_transport_space_update(struct sock *sk, |
---|
| 1018 | + struct virtio_vsock_pkt *pkt) |
---|
| 1019 | +{ |
---|
| 1020 | + struct vsock_sock *vsk = vsock_sk(sk); |
---|
| 1021 | + struct virtio_vsock_sock *vvs = vsk->trans; |
---|
| 1022 | + bool space_available; |
---|
| 1023 | + |
---|
| 1024 | + /* Listener sockets are not associated with any transport, so there |
---|
| 1025 | + * is no transport state to check for space available in the remote |
---|
| 1026 | + * peer. Since listener sockets are only used to receive requests, we |
---|
| 1027 | + * can assume that there is always space available in the other peer. |
---|
| 1028 | + */ |
---|
| 1029 | + if (!vvs) |
---|
| 1030 | + return true; |
---|
| 1031 | + |
---|
| 1032 | + /* buf_alloc and fwd_cnt are always included in the hdr */ |
---|
| 1033 | + spin_lock_bh(&vvs->tx_lock); |
---|
| 1034 | + vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc); |
---|
| 1035 | + vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt); |
---|
| 1036 | + space_available = virtio_transport_has_space(vsk); |
---|
| 1037 | + spin_unlock_bh(&vvs->tx_lock); |
---|
| 1038 | + return space_available; |
---|
| 1039 | +} |
---|
| 1040 | + |
---|
933 | 1041 | /* Handle server socket */ |
---|
934 | 1042 | static int |
---|
935 | | -virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) |
---|
| 1043 | +virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt, |
---|
| 1044 | + struct virtio_transport *t) |
---|
936 | 1045 | { |
---|
937 | 1046 | struct vsock_sock *vsk = vsock_sk(sk); |
---|
938 | 1047 | struct vsock_sock *vchild; |
---|
939 | 1048 | struct sock *child; |
---|
| 1049 | + int ret; |
---|
940 | 1050 | |
---|
941 | 1051 | if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) { |
---|
942 | | - virtio_transport_reset(vsk, pkt); |
---|
| 1052 | + virtio_transport_reset_no_sock(t, pkt); |
---|
943 | 1053 | return -EINVAL; |
---|
944 | 1054 | } |
---|
945 | 1055 | |
---|
946 | 1056 | if (sk_acceptq_is_full(sk)) { |
---|
947 | | - virtio_transport_reset(vsk, pkt); |
---|
| 1057 | + virtio_transport_reset_no_sock(t, pkt); |
---|
948 | 1058 | return -ENOMEM; |
---|
949 | 1059 | } |
---|
950 | 1060 | |
---|
951 | | - child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL, |
---|
952 | | - sk->sk_type, 0); |
---|
| 1061 | + child = vsock_create_connected(sk); |
---|
953 | 1062 | if (!child) { |
---|
954 | | - virtio_transport_reset(vsk, pkt); |
---|
| 1063 | + virtio_transport_reset_no_sock(t, pkt); |
---|
955 | 1064 | return -ENOMEM; |
---|
956 | 1065 | } |
---|
957 | 1066 | |
---|
958 | | - sk->sk_ack_backlog++; |
---|
| 1067 | + sk_acceptq_added(sk); |
---|
959 | 1068 | |
---|
960 | 1069 | lock_sock_nested(child, SINGLE_DEPTH_NESTING); |
---|
961 | 1070 | |
---|
.. | .. |
---|
967 | 1076 | vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid), |
---|
968 | 1077 | le32_to_cpu(pkt->hdr.src_port)); |
---|
969 | 1078 | |
---|
| 1079 | + ret = vsock_assign_transport(vchild, vsk); |
---|
| 1080 | + /* Transport assigned (looking at remote_addr) must be the same |
---|
| 1081 | + * one on which we received the request. |
---|
| 1082 | + */ |
---|
| 1083 | + if (ret || vchild->transport != &t->transport) { |
---|
| 1084 | + release_sock(child); |
---|
| 1085 | + virtio_transport_reset_no_sock(t, pkt); |
---|
| 1086 | + sock_put(child); |
---|
| 1087 | + return ret; |
---|
| 1088 | + } |
---|
| 1089 | + |
---|
| 1090 | + if (virtio_transport_space_update(child, pkt)) |
---|
| 1091 | + child->sk_write_space(child); |
---|
| 1092 | + |
---|
970 | 1093 | vsock_insert_connected(vchild); |
---|
971 | 1094 | vsock_enqueue_accept(sk, child); |
---|
972 | 1095 | virtio_transport_send_response(vchild, pkt); |
---|
.. | .. |
---|
975 | 1098 | |
---|
976 | 1099 | sk->sk_data_ready(sk); |
---|
977 | 1100 | return 0; |
---|
978 | | -} |
---|
979 | | - |
---|
980 | | -static bool virtio_transport_space_update(struct sock *sk, |
---|
981 | | - struct virtio_vsock_pkt *pkt) |
---|
982 | | -{ |
---|
983 | | - struct vsock_sock *vsk = vsock_sk(sk); |
---|
984 | | - struct virtio_vsock_sock *vvs = vsk->trans; |
---|
985 | | - bool space_available; |
---|
986 | | - |
---|
987 | | - /* buf_alloc and fwd_cnt is always included in the hdr */ |
---|
988 | | - spin_lock_bh(&vvs->tx_lock); |
---|
989 | | - vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc); |
---|
990 | | - vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt); |
---|
991 | | - space_available = virtio_transport_has_space(vsk); |
---|
992 | | - spin_unlock_bh(&vvs->tx_lock); |
---|
993 | | - return space_available; |
---|
994 | 1101 | } |
---|
995 | 1102 | |
---|
996 | 1103 | /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex |
---|
.. | .. |
---|
1039 | 1146 | |
---|
1040 | 1147 | lock_sock(sk); |
---|
1041 | 1148 | |
---|
| 1149 | + /* Check if sk has been closed before lock_sock */ |
---|
| 1150 | + if (sock_flag(sk, SOCK_DONE)) { |
---|
| 1151 | + (void)virtio_transport_reset_no_sock(t, pkt); |
---|
| 1152 | + release_sock(sk); |
---|
| 1153 | + sock_put(sk); |
---|
| 1154 | + goto free_pkt; |
---|
| 1155 | + } |
---|
| 1156 | + |
---|
1042 | 1157 | space_available = virtio_transport_space_update(sk, pkt); |
---|
1043 | 1158 | |
---|
1044 | 1159 | /* Update CID in case it has changed after a transport reset event */ |
---|
1045 | | - vsk->local_addr.svm_cid = dst.svm_cid; |
---|
| 1160 | + if (vsk->local_addr.svm_cid != VMADDR_CID_ANY) |
---|
| 1161 | + vsk->local_addr.svm_cid = dst.svm_cid; |
---|
1046 | 1162 | |
---|
1047 | 1163 | if (space_available) |
---|
1048 | 1164 | sk->sk_write_space(sk); |
---|
1049 | 1165 | |
---|
1050 | 1166 | switch (sk->sk_state) { |
---|
1051 | 1167 | case TCP_LISTEN: |
---|
1052 | | - virtio_transport_recv_listen(sk, pkt); |
---|
| 1168 | + virtio_transport_recv_listen(sk, pkt, t); |
---|
1053 | 1169 | virtio_transport_free_pkt(pkt); |
---|
1054 | 1170 | break; |
---|
1055 | 1171 | case TCP_SYN_SENT: |
---|
.. | .. |
---|
1068 | 1184 | virtio_transport_free_pkt(pkt); |
---|
1069 | 1185 | break; |
---|
1070 | 1186 | } |
---|
| 1187 | + |
---|
1071 | 1188 | release_sock(sk); |
---|
1072 | 1189 | |
---|
1073 | 1190 | /* Release refcnt obtained when we fetched this socket out of the |
---|
.. | .. |
---|
1083 | 1200 | |
---|
1084 | 1201 | void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt) |
---|
1085 | 1202 | { |
---|
1086 | | - kfree(pkt->buf); |
---|
| 1203 | + kvfree(pkt->buf); |
---|
1087 | 1204 | kfree(pkt); |
---|
1088 | 1205 | } |
---|
1089 | 1206 | EXPORT_SYMBOL_GPL(virtio_transport_free_pkt); |
---|