| .. | .. |
|---|
| 1 | +// SPDX-License-Identifier: GPL-2.0-only |
|---|
| 1 | 2 | /* |
|---|
| 2 | 3 | * virtio transport for vsock |
|---|
| 3 | 4 | * |
|---|
| .. | .. |
|---|
| 7 | 8 | * |
|---|
| 8 | 9 | * Some of the code is take from Gerd Hoffmann <kraxel@redhat.com>'s |
|---|
| 9 | 10 | * early virtio-vsock proof-of-concept bits. |
|---|
| 10 | | - * |
|---|
| 11 | | - * This work is licensed under the terms of the GNU GPL, version 2. |
|---|
| 12 | 11 | */ |
|---|
| 13 | 12 | #include <linux/spinlock.h> |
|---|
| 14 | 13 | #include <linux/module.h> |
|---|
| .. | .. |
|---|
| 23 | 22 | #include <net/af_vsock.h> |
|---|
| 24 | 23 | |
|---|
| 25 | 24 | static struct workqueue_struct *virtio_vsock_workqueue; |
|---|
| 26 | | -static struct virtio_vsock *the_virtio_vsock; |
|---|
| 25 | +static struct virtio_vsock __rcu *the_virtio_vsock; |
|---|
| 27 | 26 | static DEFINE_MUTEX(the_virtio_vsock_mutex); /* protects the_virtio_vsock */ |
|---|
| 28 | 27 | |
|---|
| 29 | 28 | struct virtio_vsock { |
|---|
| .. | .. |
|---|
| 44 | 43 | struct work_struct send_pkt_work; |
|---|
| 45 | 44 | spinlock_t send_pkt_list_lock; |
|---|
| 46 | 45 | struct list_head send_pkt_list; |
|---|
| 47 | | - |
|---|
| 48 | | - struct work_struct loopback_work; |
|---|
| 49 | | - spinlock_t loopback_list_lock; /* protects loopback_list */ |
|---|
| 50 | | - struct list_head loopback_list; |
|---|
| 51 | 46 | |
|---|
| 52 | 47 | atomic_t queued_replies; |
|---|
| 53 | 48 | |
|---|
| .. | .. |
|---|
| 85 | 80 | out_rcu: |
|---|
| 86 | 81 | rcu_read_unlock(); |
|---|
| 87 | 82 | return ret; |
|---|
| 88 | | -} |
|---|
| 89 | | - |
|---|
| 90 | | -static int virtio_transport_send_pkt_loopback(struct virtio_vsock *vsock, |
|---|
| 91 | | - struct virtio_vsock_pkt *pkt) |
|---|
| 92 | | -{ |
|---|
| 93 | | - int len = pkt->len; |
|---|
| 94 | | - |
|---|
| 95 | | - spin_lock_bh(&vsock->loopback_list_lock); |
|---|
| 96 | | - list_add_tail(&pkt->list, &vsock->loopback_list); |
|---|
| 97 | | - spin_unlock_bh(&vsock->loopback_list_lock); |
|---|
| 98 | | - |
|---|
| 99 | | - queue_work(virtio_vsock_workqueue, &vsock->loopback_work); |
|---|
| 100 | | - |
|---|
| 101 | | - return len; |
|---|
| 102 | 83 | } |
|---|
| 103 | 84 | |
|---|
| 104 | 85 | static void |
|---|
| .. | .. |
|---|
| 195 | 176 | } |
|---|
| 196 | 177 | |
|---|
| 197 | 178 | if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid) { |
|---|
| 198 | | - len = virtio_transport_send_pkt_loopback(vsock, pkt); |
|---|
| 179 | + virtio_transport_free_pkt(pkt); |
|---|
| 180 | + len = -ENODEV; |
|---|
| 199 | 181 | goto out_rcu; |
|---|
| 200 | 182 | } |
|---|
| 201 | 183 | |
|---|
| .. | .. |
|---|
| 281 | 263 | break; |
|---|
| 282 | 264 | } |
|---|
| 283 | 265 | |
|---|
| 266 | + pkt->buf_len = buf_len; |
|---|
| 284 | 267 | pkt->len = buf_len; |
|---|
| 285 | 268 | |
|---|
| 286 | 269 | sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr)); |
|---|
| .. | .. |
|---|
| 465 | 448 | |
|---|
| 466 | 449 | static struct virtio_transport virtio_transport = { |
|---|
| 467 | 450 | .transport = { |
|---|
| 451 | + .module = THIS_MODULE, |
|---|
| 452 | + |
|---|
| 468 | 453 | .get_local_cid = virtio_transport_get_local_cid, |
|---|
| 469 | 454 | |
|---|
| 470 | 455 | .init = virtio_transport_do_socket_init, |
|---|
| .. | .. |
|---|
| 497 | 482 | .notify_send_pre_block = virtio_transport_notify_send_pre_block, |
|---|
| 498 | 483 | .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue, |
|---|
| 499 | 484 | .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue, |
|---|
| 500 | | - |
|---|
| 501 | | - .set_buffer_size = virtio_transport_set_buffer_size, |
|---|
| 502 | | - .set_min_buffer_size = virtio_transport_set_min_buffer_size, |
|---|
| 503 | | - .set_max_buffer_size = virtio_transport_set_max_buffer_size, |
|---|
| 504 | | - .get_buffer_size = virtio_transport_get_buffer_size, |
|---|
| 505 | | - .get_min_buffer_size = virtio_transport_get_min_buffer_size, |
|---|
| 506 | | - .get_max_buffer_size = virtio_transport_get_max_buffer_size, |
|---|
| 485 | + .notify_buffer_size = virtio_transport_notify_buffer_size, |
|---|
| 507 | 486 | }, |
|---|
| 508 | 487 | |
|---|
| 509 | 488 | .send_pkt = virtio_transport_send_pkt, |
|---|
| 510 | 489 | }; |
|---|
| 511 | | - |
|---|
| 512 | | -static void virtio_transport_loopback_work(struct work_struct *work) |
|---|
| 513 | | -{ |
|---|
| 514 | | - struct virtio_vsock *vsock = |
|---|
| 515 | | - container_of(work, struct virtio_vsock, loopback_work); |
|---|
| 516 | | - LIST_HEAD(pkts); |
|---|
| 517 | | - |
|---|
| 518 | | - spin_lock_bh(&vsock->loopback_list_lock); |
|---|
| 519 | | - list_splice_init(&vsock->loopback_list, &pkts); |
|---|
| 520 | | - spin_unlock_bh(&vsock->loopback_list_lock); |
|---|
| 521 | | - |
|---|
| 522 | | - mutex_lock(&vsock->rx_lock); |
|---|
| 523 | | - |
|---|
| 524 | | - if (!vsock->rx_run) |
|---|
| 525 | | - goto out; |
|---|
| 526 | | - |
|---|
| 527 | | - while (!list_empty(&pkts)) { |
|---|
| 528 | | - struct virtio_vsock_pkt *pkt; |
|---|
| 529 | | - |
|---|
| 530 | | - pkt = list_first_entry(&pkts, struct virtio_vsock_pkt, list); |
|---|
| 531 | | - list_del_init(&pkt->list); |
|---|
| 532 | | - |
|---|
| 533 | | - virtio_transport_recv_pkt(&virtio_transport, pkt); |
|---|
| 534 | | - } |
|---|
| 535 | | -out: |
|---|
| 536 | | - mutex_unlock(&vsock->rx_lock); |
|---|
| 537 | | -} |
|---|
| 538 | 490 | |
|---|
| 539 | 491 | static void virtio_transport_rx_work(struct work_struct *work) |
|---|
| 540 | 492 | { |
|---|
| .. | .. |
|---|
| 640 | 592 | mutex_init(&vsock->event_lock); |
|---|
| 641 | 593 | spin_lock_init(&vsock->send_pkt_list_lock); |
|---|
| 642 | 594 | INIT_LIST_HEAD(&vsock->send_pkt_list); |
|---|
| 643 | | - spin_lock_init(&vsock->loopback_list_lock); |
|---|
| 644 | | - INIT_LIST_HEAD(&vsock->loopback_list); |
|---|
| 645 | 595 | INIT_WORK(&vsock->rx_work, virtio_transport_rx_work); |
|---|
| 646 | 596 | INIT_WORK(&vsock->tx_work, virtio_transport_tx_work); |
|---|
| 647 | 597 | INIT_WORK(&vsock->event_work, virtio_transport_event_work); |
|---|
| 648 | 598 | INIT_WORK(&vsock->send_pkt_work, virtio_transport_send_pkt_work); |
|---|
| 649 | | - INIT_WORK(&vsock->loopback_work, virtio_transport_loopback_work); |
|---|
| 650 | 599 | |
|---|
| 651 | 600 | mutex_lock(&vsock->tx_lock); |
|---|
| 652 | 601 | vsock->tx_run = true; |
|---|
| .. | .. |
|---|
| 684 | 633 | vdev->priv = NULL; |
|---|
| 685 | 634 | rcu_assign_pointer(the_virtio_vsock, NULL); |
|---|
| 686 | 635 | synchronize_rcu(); |
|---|
| 687 | | - |
|---|
| 688 | | - flush_work(&vsock->loopback_work); |
|---|
| 689 | | - flush_work(&vsock->rx_work); |
|---|
| 690 | | - flush_work(&vsock->tx_work); |
|---|
| 691 | | - flush_work(&vsock->event_work); |
|---|
| 692 | | - flush_work(&vsock->send_pkt_work); |
|---|
| 693 | 636 | |
|---|
| 694 | 637 | /* Reset all connected sockets when the device disappear */ |
|---|
| 695 | 638 | vsock_for_each_connected_socket(virtio_vsock_reset_sock); |
|---|
| .. | .. |
|---|
| 733 | 676 | } |
|---|
| 734 | 677 | spin_unlock_bh(&vsock->send_pkt_list_lock); |
|---|
| 735 | 678 | |
|---|
| 736 | | - spin_lock_bh(&vsock->loopback_list_lock); |
|---|
| 737 | | - while (!list_empty(&vsock->loopback_list)) { |
|---|
| 738 | | - pkt = list_first_entry(&vsock->loopback_list, |
|---|
| 739 | | - struct virtio_vsock_pkt, list); |
|---|
| 740 | | - list_del(&pkt->list); |
|---|
| 741 | | - virtio_transport_free_pkt(pkt); |
|---|
| 742 | | - } |
|---|
| 743 | | - spin_unlock_bh(&vsock->loopback_list_lock); |
|---|
| 744 | | - |
|---|
| 745 | 679 | /* Delete virtqueues and flush outstanding callbacks if any */ |
|---|
| 746 | 680 | vdev->config->del_vqs(vdev); |
|---|
| 681 | + |
|---|
| 682 | + /* Other works can be queued before 'config->del_vqs()', so we flush |
|---|
| 683 | + * all works before to free the vsock object to avoid use after free. |
|---|
| 684 | + */ |
|---|
| 685 | + flush_work(&vsock->rx_work); |
|---|
| 686 | + flush_work(&vsock->tx_work); |
|---|
| 687 | + flush_work(&vsock->event_work); |
|---|
| 688 | + flush_work(&vsock->send_pkt_work); |
|---|
| 747 | 689 | |
|---|
| 748 | 690 | mutex_unlock(&the_virtio_vsock_mutex); |
|---|
| 749 | 691 | |
|---|
| .. | .. |
|---|
| 776 | 718 | if (!virtio_vsock_workqueue) |
|---|
| 777 | 719 | return -ENOMEM; |
|---|
| 778 | 720 | |
|---|
| 779 | | - ret = vsock_core_init(&virtio_transport.transport); |
|---|
| 721 | + ret = vsock_core_register(&virtio_transport.transport, |
|---|
| 722 | + VSOCK_TRANSPORT_F_G2H); |
|---|
| 780 | 723 | if (ret) |
|---|
| 781 | 724 | goto out_wq; |
|---|
| 782 | 725 | |
|---|
| .. | .. |
|---|
| 787 | 730 | return 0; |
|---|
| 788 | 731 | |
|---|
| 789 | 732 | out_vci: |
|---|
| 790 | | - vsock_core_exit(); |
|---|
| 733 | + vsock_core_unregister(&virtio_transport.transport); |
|---|
| 791 | 734 | out_wq: |
|---|
| 792 | 735 | destroy_workqueue(virtio_vsock_workqueue); |
|---|
| 793 | 736 | return ret; |
|---|
| .. | .. |
|---|
| 796 | 739 | static void __exit virtio_vsock_exit(void) |
|---|
| 797 | 740 | { |
|---|
| 798 | 741 | unregister_virtio_driver(&virtio_vsock_driver); |
|---|
| 799 | | - vsock_core_exit(); |
|---|
| 742 | + vsock_core_unregister(&virtio_transport.transport); |
|---|
| 800 | 743 | destroy_workqueue(virtio_vsock_workqueue); |
|---|
| 801 | 744 | } |
|---|
| 802 | 745 | |
|---|